diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 94919ccf2..62b4ce36b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -131,8 +131,21 @@ jobs: tests/test_mllm_cache.py \ tests/test_optimizations.py \ tests/test_simple_engine.py \ + tests/test_chat_template_kwargs.py \ tests/test_batching.py \ tests/test_continuous_batching.py \ + tests/test_memory_cache_mlx.py \ + -v --tb=short \ + -m "not slow" \ + -k "not Integration" + + - name: Run EngineCore stream-affinity regression tests + run: | + # Fresh process so globals like mlx_lm.generate.generation_stream + # are not rebound by earlier tests (see issue #407). + pytest \ + tests/test_batching_deterministic.py \ + tests/test_engine_core_stream_safety.py \ -v --tb=short \ -m "not slow" \ -k "not Integration" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e2bc037a..8c8bda79f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,8 +13,14 @@ repos: rev: v0.1.9 hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix] - - id: ruff-format + args: [--select, E,F,W, --ignore, E402,E501,E731,F811,F841] + + - repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + args: [--check] + files: ^(vllm_mlx|tests)/ - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.8.0 diff --git a/README.es.md b/README.es.md new file mode 100644 index 000000000..37b20a4b6 --- /dev/null +++ b/README.es.md @@ -0,0 +1,302 @@ +# vllm-mlx + +**Lee esto en otros idiomas:** [English](README.md) · [Español](README.es.md) · [Français](README.fr.md) · [中文](README.zh.md) + +**Continuous batching + APIs OpenAI y Anthropic en un solo servidor. Inferencia nativa en Apple Silicon.** + +[![PyPI version](https://img.shields.io/pypi/v/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![PyPI Downloads](https://img.shields.io/pypi/dm/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) +[![Apple Silicon](https://img.shields.io/badge/Apple-Silicon-black.svg)](https://support.apple.com/en-us/HT211814) +[![GitHub stars](https://img.shields.io/github/stars/waybarrios/vllm-mlx.svg?style=social)](https://github.com/waybarrios/vllm-mlx) + +--- + +## ¿Qué es vllm-mlx? + +Un servidor de inferencia estilo vLLM para Macs con Apple Silicon. A diferencia de usar `Ollama` o `mlx-lm` directamente, incluye **continuous batching, paged KV cache, prefix caching y KV cache en SSD**, y expone **tanto OpenAI `/v1/*` como Anthropic `/v1/messages`** desde un solo proceso. Corre LLMs, modelos de visión, audio y embeddings sobre Metal con memoria unificada, sin paso de conversión. 
+ +## Inicio rápido (30 segundos) + +```bash +pip install vllm-mlx +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +**SDK de OpenAI:** + +```python +from openai import OpenAI +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") +r = client.chat.completions.create(model="default", messages=[{"role": "user", "content": "Hola!"}]) +print(r.choices[0].message.content) +``` + +**SDK de Anthropic / Claude Code:** + +```bash +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +## Características + +### APIs +- **Compatible con OpenAI**: `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/rerank`, `/v1/responses` +- **Compatible con Anthropic**: `/v1/messages` (streaming, tool use, system prompts) +- **MCP Tool Calling**: 12 parsers (OpenAI, Anthropic, Gemini, Qwen, DeepSeek, Gemma y más) +- **Salida estructurada**: JSON Schema vía `response_format` (lm-format-enforcer) + +### Throughput y memoria +- **Continuous batching**: alto throughput para requests concurrentes +- **Paged KV cache**: eficiente en memoria con prefix sharing +- **KV cache en SSD**: volcá el prefix cache a disco para agentes con contexto largo (`--ssd-cache-dir`) +- **Warm prompts**: precargá prefixes populares al arrancar (`--warm-prompts`) para 1.3-2.25x de TTFT +- **Prefix cache**: basado en trie, compartido entre requests + +### Multimodal +- **Texto + imagen + video + audio** desde un solo servidor +- Modelos de visión: Gemma 3, Gemma 4, Qwen3-VL, Pixtral, Llama vision +- **Audio de entrada** en el chat (bloques `audio_url`) +- **TTS nativo**: 11 voces, 15+ idiomas (Kokoro, Chatterbox, VibeVoice, VoxCPM) +- **STT**: familia Whisper con RTF hasta 197x en M4 Max + +### Razonamiento y avanzado +- **Extracción de razonamiento**: Qwen3, DeepSeek-R1 (`--reasoning-parser`) +- **Reducción de expertos MoE**: `--moe-top-k` para +7-16% en Qwen3-30B-A3B +- **Decodificación especulativa**: `--mtp` para Qwen3-Next +- **Prefill disperso**: `--spec-prefill` basado en atención para reducir TTFT + +### Observabilidad +- **Métricas Prometheus**: endpoint `/metrics` con `--metrics` +- **Benchmarker incluido**: `vllm-mlx bench-serve` para barridos de prompts con salida CSV/JSON + +### Aceleración GPU nativa +- Solo Apple Silicon (M1, M2, M3, M4) con kernels Metal vía MLX +- Memoria unificada, sin conversión de modelos + +## Rendimiento + +**Decode de LLM (M4 Max, 128 GB, greedy, single stream):** + +| Modelo | Tok/s | Memoria | +|--------|------:|--------:| +| Qwen3-0.6B-8bit | 417.9 | 0.7 GB | +| Llama-3.2-3B-Instruct-4bit | 205.6 | 1.8 GB | +| Qwen3-30B-A3B-4bit | 127.7 | ~18 GB | + +**Audio speech-to-text (M4 Max, RTF = real-time factor):** + +| Modelo | RTF | Caso de uso | +|--------|----:|-------------| +| whisper-tiny | 197x | Tiempo real / baja latencia | +| whisper-large-v3-turbo | 55x | Calidad + velocidad | +| whisper-large-v3 | 24x | Máxima precisión | + +Ver [docs/benchmarks/](docs/benchmarks/) para resultados de continuous batching, cuantización de KV cache (4-bit / 8-bit / fp16) y barridos de MoE top-k. 
+ +## Ejemplos + +### API Anthropic (Claude Code, OpenCode) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --port 8000 +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### Modelos de razonamiento (Qwen3, DeepSeek-R1) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "¿Cuánto es 17 * 23?"}], +) +print("Pensamiento:", r.choices[0].message.reasoning) +print("Respuesta:", r.choices[0].message.content) +``` + +### Multimodal (imagen + texto) + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": [ + {"type": "text", "text": "¿Qué hay en esta imagen?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}, + ]}], +) +``` + +### Salida estructurada (JSON Schema) + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Lista 3 colores."}], + response_format={ + "type": "json_schema", + "json_schema": { + "schema": {"type": "object", "properties": {"colors": {"type": "array", "items": {"type": "string"}}}} + }, + }, +) +``` + +### Reranking (`/v1/rerank`) + +```bash +curl http://localhost:8000/v1/rerank -H 'Content-Type: application/json' -d '{ + "model": "default", + "query": "inferencia en apple silicon", + "documents": ["MLX es el framework de Apple", "Kernels Metal en M-series", "CUDA en NVIDIA"] +}' +``` + +### Embeddings + +```bash +vllm-mlx serve --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +```python +emb = client.embeddings.create(model="mlx-community/all-MiniLM-L6-v2-4bit", input=["Hola", "Mundo"]) +``` + +### Audio (TTS / STT) + +```bash +pip install vllm-mlx[audio] +brew install espeak-ng # macOS, necesario para TTS en idiomas no-inglés + +python examples/tts_example.py "Hello, how are you?" --play +python examples/tts_multilingual.py "Hola mundo" --lang es --play +``` + +### Benchmarking incluido + +```bash +vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv +``` + +### Métricas Prometheus + +```bash +vllm-mlx serve --metrics +curl http://localhost:8000/metrics +``` + +## Instalación + +**Usando uv (recomendado):** + +```bash +uv tool install vllm-mlx # CLI a nivel sistema +# o dentro de un proyecto +uv pip install vllm-mlx +``` + +**Usando pip:** + +```bash +pip install vllm-mlx + +# Extras de audio +pip install vllm-mlx[audio] +brew install espeak-ng +python -m spacy download en_core_web_sm +``` + +**Desde el código fuente:** + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx +pip install -e . +``` + +Ver [Guía de instalación](docs/getting-started/installation.md) para todas las opciones. 
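Una prueba rápida después de instalar (un esquema orientativo, asumiendo que el servidor del inicio rápido sigue corriendo en el puerto 8000; los campos exactos de las respuestas JSON pueden variar):

```bash
# Chequeo de disponibilidad — responde una vez cargado el modelo
curl http://localhost:8000/health

# Listar los modelos que expone el servidor (endpoint compatible con OpenAI)
curl http://localhost:8000/v1/models
```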
+ +## Documentación + +- **Primeros pasos**: [Instalación](docs/getting-started/installation.md) · [Inicio rápido](docs/getting-started/quickstart.md) +- **Servidores y APIs**: [Servidor OpenAI](docs/guides/server.md) · [API Anthropic Messages](docs/guides/server.md#anthropic-messages-api) · [API Python](docs/guides/python-api.md) +- **Características**: [Multimodal](docs/guides/multimodal.md) · [Audio](docs/guides/audio.md) · [Embeddings](docs/guides/embeddings.md) · [Razonamiento](docs/guides/reasoning.md) · [MCP y Tool Calling](docs/guides/mcp-tools.md) · [Parsers de tools](docs/guides/tool-calling.md) +- **Rendimiento**: [Continuous Batching](docs/guides/continuous-batching.md) · [Warm Prompts](docs/guides/warm-prompts.md) · [MoE Top-K](docs/guides/moe-top-k.md) +- **Referencia**: [CLI](docs/reference/cli.md) · [Modelos](docs/reference/models.md) · [Configuración](docs/reference/configuration.md) +- **Benchmarks**: [LLM](docs/benchmarks/llm.md) · [Imagen](docs/benchmarks/image.md) · [Video](docs/benchmarks/video.md) · [Audio](docs/benchmarks/audio.md) + +## Arquitectura + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Servidor vllm-mlx │ +│ OpenAI /v1/* · Anthropic /v1/messages · /v1/rerank · /metrics │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Continuous batching · Paged KV cache · Prefix cache · SSD tiering │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ┌─────────────┬────────────┴────────────┬─────────────┐ + ▼ ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ mlx-lm │ │ mlx-vlm │ │ mlx-audio │ │mlx-embeddings │ +│ (LLMs) │ │ (Visión) │ │ (TTS + STT) │ │ (Embeddings) │ +└───────────────┘ └───────────────┘ └───────────────┘ └───────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ MLX · kernels Metal · memoria unificada │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +## Contribuir + +Bienvenidos bug fixes, trabajo de performance, docs y benchmarks en distintos chips de Apple Silicon. Ver [Guía de contribución](docs/development/contributing.md). + +## Licencia + +Apache 2.0. Ver [LICENSE](LICENSE). + +## Citación + +```bibtex +@software{vllm_mlx2025, + author = {Barrios, Wayner}, + title = {vllm-mlx: Apple Silicon MLX Backend for vLLM}, + year = {2025}, + url = {https://github.com/waybarrios/vllm-mlx}, + note = {Native GPU-accelerated LLM and vision-language model inference on Apple Silicon} +} +``` + +## Agradecimientos + +- [MLX](https://github.com/ml-explore/mlx). Framework de ML de Apple. +- [mlx-lm](https://github.com/ml-explore/mlx-lm). Librería de inferencia de LLM. +- [mlx-vlm](https://github.com/Blaizzy/mlx-vlm). Modelos de visión y lenguaje. +- [mlx-audio](https://github.com/Blaizzy/mlx-audio). Text-to-Speech y Speech-to-Text. +- [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings). Embeddings de texto. +- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX). Fork comunitario de vllm-mlx. +- [vLLM](https://github.com/vllm-project/vllm). Servicio de LLM de alto throughput. vllm-mlx está inspirado en vLLM y adopta su diseño de continuous-batching y paged KV-cache para Apple Silicon vía MLX. 
+ +## Historia de stars + +[![Star History Chart](https://api.star-history.com/svg?repos=waybarrios/vllm-mlx&type=Date)](https://star-history.com/#waybarrios/vllm-mlx&Date) + +--- + +**Si vllm-mlx te sirvió, por favor dale una star al repo. Ayuda a que más devs de Apple Silicon lo encuentren.** diff --git a/README.fr.md b/README.fr.md new file mode 100644 index 000000000..93fb98cdf --- /dev/null +++ b/README.fr.md @@ -0,0 +1,302 @@ +# vllm-mlx + +**Lire ceci dans d'autres langues :** [English](README.md) · [Español](README.es.md) · [Français](README.fr.md) · [中文](README.zh.md) + +**Continuous batching + API OpenAI et Anthropic dans un seul serveur. Inférence native sur Apple Silicon.** + +[![PyPI version](https://img.shields.io/pypi/v/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![PyPI Downloads](https://img.shields.io/pypi/dm/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) +[![Apple Silicon](https://img.shields.io/badge/Apple-Silicon-black.svg)](https://support.apple.com/en-us/HT211814) +[![GitHub stars](https://img.shields.io/github/stars/waybarrios/vllm-mlx.svg?style=social)](https://github.com/waybarrios/vllm-mlx) + +--- + +## Qu'est-ce que vllm-mlx ? + +Un serveur d'inférence de type vLLM pour les Macs Apple Silicon. Contrairement à l'utilisation directe de `Ollama` ou `mlx-lm`, il embarque **continuous batching, paged KV cache, prefix caching et cache KV sur SSD**, et expose **à la fois OpenAI `/v1/*` et Anthropic `/v1/messages`** depuis un seul processus. Exécutez des LLMs, modèles de vision, audio et embeddings sur Metal avec mémoire unifiée, sans étape de conversion. 
+ +## Démarrage rapide (30 secondes) + +```bash +pip install vllm-mlx +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +**SDK OpenAI :** + +```python +from openai import OpenAI +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") +r = client.chat.completions.create(model="default", messages=[{"role": "user", "content": "Salut !"}]) +print(r.choices[0].message.content) +``` + +**SDK Anthropic / Claude Code :** + +```bash +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +## Fonctionnalités + +### APIs +- **Compatible OpenAI** : `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/rerank`, `/v1/responses` +- **Compatible Anthropic** : `/v1/messages` (streaming, tool use, system prompts) +- **MCP Tool Calling** : 12 parsers (OpenAI, Anthropic, Gemini, Qwen, DeepSeek, Gemma et plus) +- **Sortie structurée** : JSON Schema via `response_format` (lm-format-enforcer) + +### Débit et mémoire +- **Continuous batching** : haut débit pour des requêtes concurrentes +- **Paged KV cache** : efficace en mémoire avec prefix sharing +- **Cache KV sur SSD** : déversez le prefix cache sur disque pour les agents à long contexte (`--ssd-cache-dir`) +- **Warm prompts** : pré-chargez les préfixes populaires au démarrage (`--warm-prompts`) pour un TTFT 1.3-2.25x +- **Prefix cache** : basé sur trie, partagé entre les requêtes + +### Multimodal +- **Texte + image + vidéo + audio** depuis un seul serveur +- Modèles de vision : Gemma 3, Gemma 4, Qwen3-VL, Pixtral, Llama vision +- **Audio en entrée** dans le chat (blocs `audio_url`) +- **TTS natif** : 11 voix, 15+ langues (Kokoro, Chatterbox, VibeVoice, VoxCPM) +- **STT** : famille Whisper avec RTF jusqu'à 197x sur M4 Max + +### Raisonnement et avancé +- **Extraction du raisonnement** : Qwen3, DeepSeek-R1 (`--reasoning-parser`) +- **Réduction d'experts MoE** : `--moe-top-k` pour +7-16% sur Qwen3-30B-A3B +- **Décodage spéculatif** : `--mtp` pour Qwen3-Next +- **Prefill creux** : `--spec-prefill` basé sur l'attention pour réduire le TTFT + +### Observabilité +- **Métriques Prometheus** : endpoint `/metrics` avec `--metrics` +- **Benchmark intégré** : `vllm-mlx bench-serve` pour des balayages de prompts en CSV/JSON + +### Accélération GPU native +- Apple Silicon uniquement (M1, M2, M3, M4) avec kernels Metal via MLX +- Mémoire unifiée, sans conversion de modèles + +## Performance + +**Decode LLM (M4 Max, 128 Go, greedy, single stream) :** + +| Modèle | Tok/s | Mémoire | +|--------|------:|--------:| +| Qwen3-0.6B-8bit | 417.9 | 0.7 Go | +| Llama-3.2-3B-Instruct-4bit | 205.6 | 1.8 Go | +| Qwen3-30B-A3B-4bit | 127.7 | ~18 Go | + +**Audio speech-to-text (M4 Max, RTF = real-time factor) :** + +| Modèle | RTF | Cas d'usage | +|--------|----:|-------------| +| whisper-tiny | 197x | Temps réel / faible latence | +| whisper-large-v3-turbo | 55x | Qualité + vitesse | +| whisper-large-v3 | 24x | Précision maximale | + +Voir [docs/benchmarks/](docs/benchmarks/) pour les résultats continuous batching, la quantification du KV cache (4-bit / 8-bit / fp16) et les balayages MoE top-k. 
+ +## Exemples + +### API Anthropic (Claude Code, OpenCode) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --port 8000 +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### Modèles de raisonnement (Qwen3, DeepSeek-R1) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Combien font 17 * 23 ?"}], +) +print("Raisonnement :", r.choices[0].message.reasoning) +print("Réponse :", r.choices[0].message.content) +``` + +### Multimodal (image + texte) + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": [ + {"type": "text", "text": "Qu'y a-t-il sur cette image ?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}, + ]}], +) +``` + +### Sortie structurée (JSON Schema) + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Liste 3 couleurs."}], + response_format={ + "type": "json_schema", + "json_schema": { + "schema": {"type": "object", "properties": {"colors": {"type": "array", "items": {"type": "string"}}}} + }, + }, +) +``` + +### Reranking (`/v1/rerank`) + +```bash +curl http://localhost:8000/v1/rerank -H 'Content-Type: application/json' -d '{ + "model": "default", + "query": "inférence apple silicon", + "documents": ["MLX est le framework d Apple", "Kernels Metal sur M-series", "CUDA sur NVIDIA"] +}' +``` + +### Embeddings + +```bash +vllm-mlx serve --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +```python +emb = client.embeddings.create(model="mlx-community/all-MiniLM-L6-v2-4bit", input=["Bonjour", "Monde"]) +``` + +### Audio (TTS / STT) + +```bash +pip install vllm-mlx[audio] +brew install espeak-ng # macOS, nécessaire pour TTS non-anglais + +python examples/tts_example.py "Hello, how are you?" --play +python examples/tts_multilingual.py "Bonjour le monde" --lang fr --play +``` + +### Benchmarking intégré + +```bash +vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv +``` + +### Métriques Prometheus + +```bash +vllm-mlx serve --metrics +curl http://localhost:8000/metrics +``` + +## Installation + +**Avec uv (recommandé) :** + +```bash +uv tool install vllm-mlx # CLI global +# ou dans un projet +uv pip install vllm-mlx +``` + +**Avec pip :** + +```bash +pip install vllm-mlx + +# Extras audio +pip install vllm-mlx[audio] +brew install espeak-ng +python -m spacy download en_core_web_sm +``` + +**Depuis les sources :** + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx +pip install -e . +``` + +Voir le [Guide d'installation](docs/getting-started/installation.md) pour toutes les options. 
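Un test rapide après l'installation (une esquisse indicative, en supposant que le serveur du démarrage rapide tourne toujours sur le port 8000 ; les champs exacts des réponses JSON peuvent varier) :

```bash
# Vérification de disponibilité — répond une fois le modèle chargé
curl http://localhost:8000/health

# Lister les modèles exposés par le serveur (endpoint compatible OpenAI)
curl http://localhost:8000/v1/models
```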
+ +## Documentation + +- **Premiers pas** : [Installation](docs/getting-started/installation.md) · [Démarrage rapide](docs/getting-started/quickstart.md) +- **Serveurs et APIs** : [Serveur OpenAI](docs/guides/server.md) · [API Anthropic Messages](docs/guides/server.md#anthropic-messages-api) · [API Python](docs/guides/python-api.md) +- **Fonctionnalités** : [Multimodal](docs/guides/multimodal.md) · [Audio](docs/guides/audio.md) · [Embeddings](docs/guides/embeddings.md) · [Raisonnement](docs/guides/reasoning.md) · [MCP et Tool Calling](docs/guides/mcp-tools.md) · [Parsers de tools](docs/guides/tool-calling.md) +- **Performance** : [Continuous Batching](docs/guides/continuous-batching.md) · [Warm Prompts](docs/guides/warm-prompts.md) · [MoE Top-K](docs/guides/moe-top-k.md) +- **Référence** : [CLI](docs/reference/cli.md) · [Modèles](docs/reference/models.md) · [Configuration](docs/reference/configuration.md) +- **Benchmarks** : [LLM](docs/benchmarks/llm.md) · [Image](docs/benchmarks/image.md) · [Vidéo](docs/benchmarks/video.md) · [Audio](docs/benchmarks/audio.md) + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Serveur vllm-mlx │ +│ OpenAI /v1/* · Anthropic /v1/messages · /v1/rerank · /metrics │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Continuous batching · Paged KV cache · Prefix cache · SSD tiering │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ┌─────────────┬────────────┴────────────┬─────────────┐ + ▼ ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ mlx-lm │ │ mlx-vlm │ │ mlx-audio │ │mlx-embeddings │ +│ (LLMs) │ │ (Vision) │ │ (TTS + STT) │ │ (Embeddings) │ +└───────────────┘ └───────────────┘ └───────────────┘ └───────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ MLX · kernels Metal · mémoire unifiée │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +## Contribuer + +Les corrections de bugs, travaux de performance, docs et benchmarks sur différentes puces Apple Silicon sont bienvenus. Voir le [Guide de contribution](docs/development/contributing.md). + +## Licence + +Apache 2.0. Voir [LICENSE](LICENSE). + +## Citation + +```bibtex +@software{vllm_mlx2025, + author = {Barrios, Wayner}, + title = {vllm-mlx: Apple Silicon MLX Backend for vLLM}, + year = {2025}, + url = {https://github.com/waybarrios/vllm-mlx}, + note = {Native GPU-accelerated LLM and vision-language model inference on Apple Silicon} +} +``` + +## Remerciements + +- [MLX](https://github.com/ml-explore/mlx). Framework ML d'Apple. +- [mlx-lm](https://github.com/ml-explore/mlx-lm). Bibliothèque d'inférence LLM. +- [mlx-vlm](https://github.com/Blaizzy/mlx-vlm). Modèles vision-langage. +- [mlx-audio](https://github.com/Blaizzy/mlx-audio). Text-to-Speech et Speech-to-Text. +- [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings). Embeddings de texte. +- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX). Fork communautaire de vllm-mlx. +- [vLLM](https://github.com/vllm-project/vllm). Service LLM haut débit. vllm-mlx s'inspire de vLLM et adopte sa conception continuous-batching et paged KV-cache pour Apple Silicon via MLX. 
+ +## Historique des stars + +[![Star History Chart](https://api.star-history.com/svg?repos=waybarrios/vllm-mlx&type=Date)](https://star-history.com/#waybarrios/vllm-mlx&Date) + +--- + +**Si vllm-mlx vous a été utile, mettez une étoile au repo. Cela aide plus de développeurs Apple Silicon à le trouver.** diff --git a/README.md b/README.md index 7fc61fe46..1865b8e34 100644 --- a/README.md +++ b/README.md @@ -1,402 +1,302 @@ -# vLLM-MLX +# vllm-mlx -**vLLM-like inference for Apple Silicon** - GPU-accelerated Text, Image, Video & Audio on Mac +**Read this in other languages:** [English](README.md) · [Español](README.es.md) · [Français](README.fr.md) · [中文](README.zh.md) -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) +**Continuous batching + OpenAI + Anthropic APIs in one server. Native Apple Silicon inference.** + +[![PyPI version](https://img.shields.io/pypi/v/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![PyPI Downloads](https://img.shields.io/pypi/dm/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) [![Apple Silicon](https://img.shields.io/badge/Apple-Silicon-black.svg)](https://support.apple.com/en-us/HT211814) -[![GitHub](https://img.shields.io/badge/GitHub-waybarrios%2Fvllm--mlx-blue?logo=github)](https://github.com/waybarrios/vllm-mlx) +[![GitHub stars](https://img.shields.io/github/stars/waybarrios/vllm-mlx.svg?style=social)](https://github.com/waybarrios/vllm-mlx) -## Overview +--- -vllm-mlx brings native Apple Silicon GPU acceleration to vLLM by integrating: +## What is vllm-mlx? -- **[MLX](https://github.com/ml-explore/mlx)**: Apple's ML framework with unified memory and Metal kernels -- **[mlx-lm](https://github.com/ml-explore/mlx-lm)**: Optimized LLM inference with KV cache and quantization -- **[mlx-vlm](https://github.com/Blaizzy/mlx-vlm)**: Vision-language models for multimodal inference -- **[mlx-audio](https://github.com/Blaizzy/mlx-audio)**: Speech-to-Text and Text-to-Speech with native voices -- **[mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings)**: Text embeddings for semantic search and RAG +A vLLM-style inference server for Apple Silicon Macs. Unlike `Ollama` or `mlx-lm` used directly, it ships **continuous batching, paged KV cache, prefix caching, and SSD-tiered cache**, and exposes **both OpenAI `/v1/*` and Anthropic `/v1/messages`** from a single process. Run LLMs, vision models, audio, and embeddings on Metal with unified memory, no conversion step. 
-## Features +## Quick start (30 seconds) -- **Multimodal** - Text, Image, Video & Audio in one platform -- **Native GPU acceleration** on Apple Silicon (M1, M2, M3, M4) -- **Native TTS voices** - Spanish, French, Chinese, Japanese + 5 more languages -- **OpenAI API compatible** - drop-in replacement for OpenAI client -- **Anthropic Messages API** - native `/v1/messages` endpoint for Claude Code and OpenCode -- **Embeddings** - OpenAI-compatible `/v1/embeddings` endpoint with mlx-embeddings -- **Reasoning Models** - extract thinking process from Qwen3, DeepSeek-R1 -- **MCP Tool Calling** - integrate external tools via Model Context Protocol -- **Paged KV Cache** - memory-efficient caching with prefix sharing -- **Continuous Batching** - high throughput for multiple concurrent users +```bash +pip install vllm-mlx +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` -## Quick Start +**OpenAI SDK:** -### Installation +```python +from openai import OpenAI +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") +r = client.chat.completions.create(model="default", messages=[{"role": "user", "content": "Hi!"}]) +print(r.choices[0].message.content) +``` -**Using uv (recommended):** +**Anthropic SDK / Claude Code:** ```bash -# Install as CLI tool (system-wide) -uv tool install git+https://github.com/waybarrios/vllm-mlx.git - -# Or install in a project/virtual environment -uv pip install git+https://github.com/waybarrios/vllm-mlx.git +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude ``` -**Using pip:** +## Features -```bash -# Install from GitHub -pip install git+https://github.com/waybarrios/vllm-mlx.git +### APIs +- **OpenAI-compatible**: `/v1/chat/completions`, `/v1/completions`, `/v1/embeddings`, `/v1/rerank`, `/v1/responses` +- **Anthropic-compatible**: `/v1/messages` (streaming, tool use, system prompts) +- **MCP Tool Calling**: 12 parsers (OpenAI, Anthropic, Gemini, Qwen, DeepSeek, Gemma, and more) +- **Structured output**: JSON Schema via `response_format` (lm-format-enforcer) + +### Throughput & memory +- **Continuous batching**: high throughput for concurrent requests +- **Paged KV cache**: memory-efficient with prefix sharing +- **SSD-tiered KV cache**: spill prefix cache to disk for long-context agents (`--ssd-cache-dir`) +- **Warm prompts**: preload popular prefixes at startup (`--warm-prompts`) for 1.3-2.25x TTFT +- **Prefix cache**: trie-based, shared across requests + +### Multimodal +- **Text + image + video + audio** from one server +- Vision models: Gemma 3, Gemma 4, Qwen3-VL, Pixtral, Llama vision +- **Audio input** in chat (`audio_url` content blocks) +- **Native TTS**: 11 voices, 15+ languages (Kokoro, Chatterbox, VibeVoice, VoxCPM) +- **STT**: Whisper family with RTF up to 197x on M4 Max + +### Reasoning & advanced +- **Reasoning extraction**: Qwen3, DeepSeek-R1 (`--reasoning-parser`) +- **MoE expert reduction**: `--moe-top-k` for +7-16% on Qwen3-30B-A3B +- **Speculative decoding**: `--mtp` for Qwen3-Next +- **Sparse prefill**: attention-based `--spec-prefill` for TTFT reduction + +### Observability +- **Prometheus metrics**: `/metrics` endpoint with `--metrics` +- **Built-in benchmarker**: `vllm-mlx bench-serve` for prompt sweeps with CSV/JSON output + +### Native GPU acceleration +- Apple Silicon only (M1, M2, M3, M4) with Metal kernels via MLX +- Unified memory, no model conversion -# Or clone and install in development mode -git clone 
https://github.com/waybarrios/vllm-mlx.git -cd vllm-mlx -pip install -e . -``` +## Performance -### Start Server +**LLM decode (M4 Max, 128 GB, greedy, single stream):** -```bash -# Simple mode (single user, max throughput) -vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 +| Model | Tok/s | Memory | +|-------|------:|-------:| +| Qwen3-0.6B-8bit | 417.9 | 0.7 GB | +| Llama-3.2-3B-Instruct-4bit | 205.6 | 1.8 GB | +| Qwen3-30B-A3B-4bit | 127.7 | ~18 GB | -# Continuous batching (multiple users) -vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +**Audio speech-to-text (M4 Max, RTF = real-time factor):** -# With API key authentication -vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --api-key your-secret-key -``` +| Model | RTF | Use case | +|-------|----:|----------| +| whisper-tiny | 197x | Real-time / low latency | +| whisper-large-v3-turbo | 55x | Quality + speed | +| whisper-large-v3 | 24x | Highest accuracy | -### Use with OpenAI SDK +See [docs/benchmarks/](docs/benchmarks/) for continuous-batching results, KV-cache quantization (4-bit / 8-bit / fp16), and MoE top-k sweeps. -```python -from openai import OpenAI +## Examples -# Without API key (local development) -client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") +### Anthropic API (Claude Code, OpenCode) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --port 8000 +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### Reasoning models (Qwen3, DeepSeek-R1) -# With API key (production) -client = OpenAI(base_url="http://localhost:8000/v1", api_key="your-secret-key") +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` -response = client.chat.completions.create( +```python +r = client.chat.completions.create( model="default", - messages=[{"role": "user", "content": "Hello!"}], + messages=[{"role": "user", "content": "What is 17 * 23?"}], ) -print(response.choices[0].message.content) +print("Thinking:", r.choices[0].message.reasoning) +print("Answer:", r.choices[0].message.content) ``` -### Use with Anthropic SDK +### Multimodal (image + text) -vllm-mlx exposes an Anthropic-compatible `/v1/messages` endpoint, so tools like Claude Code and OpenCode can connect directly. 
+```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` ```python -from anthropic import Anthropic +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}, + ]}], +) +``` -client = Anthropic(base_url="http://localhost:8000", api_key="not-needed") +### Structured output (JSON Schema) -response = client.messages.create( +```python +r = client.chat.completions.create( model="default", - max_tokens=256, - messages=[{"role": "user", "content": "Hello!"}] + messages=[{"role": "user", "content": "List 3 colors."}], + response_format={ + "type": "json_schema", + "json_schema": { + "schema": {"type": "object", "properties": {"colors": {"type": "array", "items": {"type": "string"}}}} + }, + }, ) -print(response.content[0].text) ``` -To use with Claude Code: +### Reranking (`/v1/rerank`) ```bash -export ANTHROPIC_BASE_URL=http://localhost:8000 -export ANTHROPIC_API_KEY=not-needed -claude +curl http://localhost:8000/v1/rerank -H 'Content-Type: application/json' -d '{ + "model": "default", + "query": "apple silicon inference", + "documents": ["MLX is Apples framework", "Metal kernels on M-series", "CUDA on NVIDIA"] +}' ``` -See [Anthropic Messages API docs](docs/guides/server.md#anthropic-messages-api) for streaming, tool calling, system messages, and token counting. - -### Multimodal (Images & Video) +### Embeddings ```bash -vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +vllm-mlx serve --embedding-model mlx-community/all-MiniLM-L6-v2-4bit ``` ```python -response = client.chat.completions.create( - model="default", - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} - ] - }] -) +emb = client.embeddings.create(model="mlx-community/all-MiniLM-L6-v2-4bit", input=["Hello", "World"]) ``` -### Audio (TTS/STT) +### Audio (TTS / STT) ```bash -# Install audio dependencies pip install vllm-mlx[audio] -python -m spacy download en_core_web_sm -brew install espeak-ng # macOS, for non-English languages -``` +brew install espeak-ng # macOS, needed for non-English TTS -```bash -# Text-to-Speech (English) python examples/tts_example.py "Hello, how are you?" 
--play - -# Text-to-Speech (Spanish) python examples/tts_multilingual.py "Hola mundo" --lang es --play - -# List available models and languages -python examples/tts_multilingual.py --list-models -python examples/tts_multilingual.py --list-languages ``` -**Supported TTS Models:** -| Model | Languages | Description | -|-------|-----------|-------------| -| Kokoro | EN, ES, FR, JA, ZH, IT, PT, HI | Fast, 82M params, 11 voices | -| Chatterbox | 15+ languages | Expressive, voice cloning | -| VibeVoice | EN | Realtime, low latency | -| VoxCPM | ZH, EN | High quality Chinese/English | - -### Reasoning Models - -Extract the thinking process from reasoning models like Qwen3 and DeepSeek-R1: +### Built-in benchmarking ```bash -# Start server with reasoning parser -vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv ``` -```python -response = client.chat.completions.create( - model="default", - messages=[{"role": "user", "content": "What is 17 × 23?"}] -) +### Prometheus metrics -# Access reasoning separately from the answer -print("Thinking:", response.choices[0].message.reasoning) -print("Answer:", response.choices[0].message.content) +```bash +vllm-mlx serve --metrics +curl http://localhost:8000/metrics ``` -**Supported Parsers:** -| Parser | Models | Description | -|--------|--------|-------------| -| `qwen3` | Qwen3 series | Requires both `` and `` tags | -| `deepseek_r1` | DeepSeek-R1 | Handles implicit `` tag | +## Installation -### Embeddings +**Using uv (recommended):** -Generate text embeddings for semantic search, RAG, and similarity: +```bash +uv tool install vllm-mlx # CLI, system-wide +# or in a project +uv pip install vllm-mlx +``` + +**Using pip:** ```bash -# Start server with an embedding model pre-loaded -vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +pip install vllm-mlx + +# Audio extras +pip install vllm-mlx[audio] +brew install espeak-ng +python -m spacy download en_core_web_sm ``` -```python -# Generate embeddings using the OpenAI SDK -embeddings = client.embeddings.create( - model="mlx-community/all-MiniLM-L6-v2-4bit", - input=["Hello world", "How are you?"] -) -print(f"Dimensions: {len(embeddings.data[0].embedding)}") +**From source:** + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx +pip install -e . ``` -See [Embeddings Guide](docs/guides/embeddings.md) for details on supported models and lazy loading. +See [Installation Guide](docs/getting-started/installation.md) for full options. 
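A quick smoke test after installing (a sketch, assuming the quick-start server from above is still running on port 8000; the exact JSON fields in the responses may differ):

```bash
# Liveness check — responds once the model has loaded
curl http://localhost:8000/health

# List the models the server currently exposes (OpenAI-compatible endpoint)
curl http://localhost:8000/v1/models
```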
## Documentation -For full documentation, see the [docs](docs/) directory: - -- **Getting Started** - - [Installation](docs/getting-started/installation.md) - - [Quick Start](docs/getting-started/quickstart.md) - -- **User Guides** - - [OpenAI-Compatible Server](docs/guides/server.md) - - [Anthropic Messages API](docs/guides/server.md#anthropic-messages-api) - - [Python API](docs/guides/python-api.md) - - [Multimodal (Images & Video)](docs/guides/multimodal.md) - - [Audio (STT/TTS)](docs/guides/audio.md) - - [Embeddings](docs/guides/embeddings.md) - - [Reasoning Models](docs/guides/reasoning.md) - - [MCP & Tool Calling](docs/guides/mcp-tools.md) - - [Continuous Batching](docs/guides/continuous-batching.md) - -- **Reference** - - [CLI Commands](docs/reference/cli.md) - - [Supported Models](docs/reference/models.md) - - [Configuration](docs/reference/configuration.md) - -- **Benchmarks** - - [LLM Benchmarks](docs/benchmarks/llm.md) - - [Image Benchmarks](docs/benchmarks/image.md) - - [Video Benchmarks](docs/benchmarks/video.md) - - [Audio Benchmarks](docs/benchmarks/audio.md) +- **Getting started**: [Installation](docs/getting-started/installation.md) · [Quick Start](docs/getting-started/quickstart.md) +- **Servers & APIs**: [OpenAI server](docs/guides/server.md) · [Anthropic Messages API](docs/guides/server.md#anthropic-messages-api) · [Python API](docs/guides/python-api.md) +- **Features**: [Multimodal](docs/guides/multimodal.md) · [Audio](docs/guides/audio.md) · [Embeddings](docs/guides/embeddings.md) · [Reasoning](docs/guides/reasoning.md) · [MCP & Tool Calling](docs/guides/mcp-tools.md) · [Tool Parsers](docs/guides/tool-calling.md) +- **Performance**: [Continuous Batching](docs/guides/continuous-batching.md) · [Warm Prompts](docs/guides/warm-prompts.md) · [MoE Top-K](docs/guides/moe-top-k.md) +- **Reference**: [CLI](docs/reference/cli.md) · [Models](docs/reference/models.md) · [Configuration](docs/reference/configuration.md) +- **Benchmarks**: [LLM](docs/benchmarks/llm.md) · [Image](docs/benchmarks/image.md) · [Video](docs/benchmarks/video.md) · [Audio](docs/benchmarks/audio.md) ## Architecture ``` ┌─────────────────────────────────────────────────────────────────────────┐ -│ vLLM API Layer │ -│ (OpenAI-compatible interface) │ +│ vllm-mlx Server │ +│ OpenAI /v1/* · Anthropic /v1/messages · /v1/rerank · /metrics │ └─────────────────────────────────────────────────────────────────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────────────────┐ -│ MLXPlatform │ -│ (vLLM platform plugin for Apple Silicon) │ +│ Continuous batching · Paged KV cache · Prefix cache · SSD tiering │ └─────────────────────────────────────────────────────────────────────────┘ │ ┌─────────────┬────────────┴────────────┬─────────────┐ ▼ ▼ ▼ ▼ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │ mlx-lm │ │ mlx-vlm │ │ mlx-audio │ │mlx-embeddings │ -│(LLM inference)│ │ (Vision+LLM) │ │ (TTS + STT) │ │ (Embeddings) │ +│ (LLMs) │ │ (Vision) │ │ (TTS + STT) │ │ (Embeddings) │ └───────────────┘ └───────────────┘ └───────────────┘ └───────────────┘ - │ │ │ │ - └─────────────┴─────────────────────────┴─────────────┘ │ ▼ ┌─────────────────────────────────────────────────────────────────────────┐ -│ MLX │ -│ (Apple ML Framework - Metal kernels) │ +│ MLX · Metal kernels · Unified memory │ └─────────────────────────────────────────────────────────────────────────┘ ``` -## Performance - -**LLM Performance (M4 Max, 128GB):** - -| Model | Speed | Memory | -|-------|-------|--------| -| 
Qwen3-0.6B-8bit | 402 tok/s | 0.7 GB | -| Llama-3.2-1B-4bit | 464 tok/s | 0.7 GB | -| Llama-3.2-3B-4bit | 200 tok/s | 1.8 GB | - -**Continuous Batching (5 concurrent requests):** - -| Model | Single | Batched | Speedup | -|-------|--------|---------|---------| -| Qwen3-0.6B-8bit | 328 tok/s | 1112 tok/s | **3.4x** | -| Llama-3.2-1B-4bit | 299 tok/s | 613 tok/s | **2.0x** | - -**Audio - Speech-to-Text (M4 Max, 128GB):** - -| Model | RTF* | Use Case | -|-------|------|----------| -| whisper-tiny | **197x** | Real-time, low latency | -| whisper-large-v3-turbo | **55x** | Best quality/speed balance | -| whisper-large-v3 | **24x** | Highest accuracy | - -*RTF = Real-Time Factor. RTF of 100x means 1 minute transcribes in ~0.6 seconds. - -See [benchmarks](docs/benchmarks/) for detailed results. - -## Gemma 3 Support - -vllm-mlx includes native support for Gemma 3 vision models. Gemma 3 is automatically detected as MLLM. - -### Usage - -```bash -# Start server with Gemma 3 -vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 - -# Verify it loaded as MLLM (not LLM) -curl http://localhost:8000/health -# Should show: "model_type": "mllm" -``` - -### Long Context Patch (mlx-vlm) - -Gemma 3's default `sliding_window=1024` limits context to ~10K tokens on Apple Silicon (Metal GPU timeout at higher context). To enable longer context (up to ~50K tokens), patch mlx-vlm: - -**Location:** `~/.../site-packages/mlx_vlm/models/gemma3/language.py` - -Find the `make_cache` method and replace with: - -```python -def make_cache(self): - import os - # Set GEMMA3_SLIDING_WINDOW=8192 for ~40K context - # Set GEMMA3_SLIDING_WINDOW=0 for ~50K context (full KVCache) - sliding_window = int(os.environ.get('GEMMA3_SLIDING_WINDOW', self.config.sliding_window)) - - caches = [] - for i in range(self.config.num_hidden_layers): - if ( - i % self.config.sliding_window_pattern - == self.config.sliding_window_pattern - 1 - ): - caches.append(KVCache()) - elif sliding_window == 0: - caches.append(KVCache()) # Full context for all layers - else: - caches.append(RotatingKVCache(max_size=sliding_window, keep=0)) - return caches -``` - -**Usage:** - -```bash -# Default (~10K max context) -vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 - -# Extended context (~40K max) -GEMMA3_SLIDING_WINDOW=8192 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 - -# Maximum context (~50K max) -GEMMA3_SLIDING_WINDOW=0 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 -``` - -**Benchmark Results (M4 Max 128GB):** - -| Setting | Max Context | Memory | -|---------|-------------|--------| -| Default (1024) | ~10K tokens | ~16GB | -| `GEMMA3_SLIDING_WINDOW=8192` | ~40K tokens | ~25GB | -| `GEMMA3_SLIDING_WINDOW=0` | ~50K tokens | ~35GB | - ## Contributing -We welcome contributions! See [Contributing Guide](docs/development/contributing.md) for details. - -- Bug fixes and improvements -- Performance optimizations -- Documentation improvements -- Benchmarks on different Apple Silicon chips - -Submit PRs to: [https://github.com/waybarrios/vllm-mlx](https://github.com/waybarrios/vllm-mlx) +Bug fixes, perf work, docs, and benchmarks on different Apple Silicon chips all welcome. See the [Contributing Guide](docs/development/contributing.md). ## License -Apache 2.0 - see [LICENSE](LICENSE) for details. +Apache 2.0. See [LICENSE](LICENSE). 
## Citation -If you use vLLM-MLX in your research or project, please cite: - ```bibtex @software{vllm_mlx2025, author = {Barrios, Wayner}, - title = {vLLM-MLX: Apple Silicon MLX Backend for vLLM}, - year = {2025}, - url = {https://github.com/waybarrios/vllm-mlx}, - note = {Native GPU-accelerated LLM and vision-language model inference on Apple Silicon} + title = {vllm-mlx: Apple Silicon MLX Backend for vLLM}, + year = {2025}, + url = {https://github.com/waybarrios/vllm-mlx}, + note = {Native GPU-accelerated LLM and vision-language model inference on Apple Silicon} } ``` ## Acknowledgments -- [MLX](https://github.com/ml-explore/mlx) - Apple's ML framework -- [mlx-lm](https://github.com/ml-explore/mlx-lm) - LLM inference library -- [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) - Vision-language models -- [mlx-audio](https://github.com/Blaizzy/mlx-audio) - Text-to-Speech and Speech-to-Text -- [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings) - Text embeddings -- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX) - Community fork of vllm-mlx -- [vLLM](https://github.com/vllm-project/vllm) - High-throughput LLM serving +- [MLX](https://github.com/ml-explore/mlx). Apple's ML framework. +- [mlx-lm](https://github.com/ml-explore/mlx-lm). LLM inference library. +- [mlx-vlm](https://github.com/Blaizzy/mlx-vlm). Vision-language models. +- [mlx-audio](https://github.com/Blaizzy/mlx-audio). Text-to-Speech and Speech-to-Text. +- [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings). Text embeddings. +- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX). Community fork of vllm-mlx. +- [vLLM](https://github.com/vllm-project/vllm). High-throughput LLM serving. vllm-mlx is inspired by vLLM and adopts its continuous-batching and paged KV-cache design for Apple Silicon via MLX. + +## Star history + +[![Star History Chart](https://api.star-history.com/svg?repos=waybarrios/vllm-mlx&type=Date)](https://star-history.com/#waybarrios/vllm-mlx&Date) + +--- + +**If vllm-mlx helped you, please star the repo. It helps more Apple Silicon devs find it.** diff --git a/README.zh.md b/README.zh.md new file mode 100644 index 000000000..affbba16c --- /dev/null +++ b/README.zh.md @@ -0,0 +1,302 @@ +# vllm-mlx + +**其他语言版本:** [English](README.md) · [Español](README.es.md) · [Français](README.fr.md) · [中文](README.zh.md) + +**连续批处理 + OpenAI 和 Anthropic API 集成于一个服务。Apple Silicon 原生推理。** + +[![PyPI version](https://img.shields.io/pypi/v/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![PyPI Downloads](https://img.shields.io/pypi/dm/vllm-mlx.svg)](https://pypi.org/project/vllm-mlx/) +[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) +[![Apple Silicon](https://img.shields.io/badge/Apple-Silicon-black.svg)](https://support.apple.com/en-us/HT211814) +[![GitHub stars](https://img.shields.io/github/stars/waybarrios/vllm-mlx.svg?style=social)](https://github.com/waybarrios/vllm-mlx) + +--- + +## vllm-mlx 是什么? 
+ +面向 Apple Silicon Mac 的 vLLM 风格推理服务器。与直接使用 `Ollama` 或 `mlx-lm` 不同,vllm-mlx 内置了**连续批处理(continuous batching)、分页 KV cache、prefix caching 以及 SSD 分层 KV cache**,并在同一个进程中同时暴露 **OpenAI `/v1/*` 和 Anthropic `/v1/messages`** 接口。可在 Metal 上通过统一内存运行 LLM、视觉模型、音频模型和嵌入模型,无需任何格式转换步骤。 + +## 快速开始(30 秒) + +```bash +pip install vllm-mlx +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +**OpenAI SDK:** + +```python +from openai import OpenAI +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") +r = client.chat.completions.create(model="default", messages=[{"role": "user", "content": "你好!"}]) +print(r.choices[0].message.content) +``` + +**Anthropic SDK / Claude Code:** + +```bash +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +## 功能特性 + +### API +- **兼容 OpenAI**:`/v1/chat/completions`、`/v1/completions`、`/v1/embeddings`、`/v1/rerank`、`/v1/responses` +- **兼容 Anthropic**:`/v1/messages`(流式、工具调用、system prompts) +- **MCP 工具调用**:12 种解析器(OpenAI、Anthropic、Gemini、Qwen、DeepSeek、Gemma 等) +- **结构化输出**:通过 `response_format` 的 JSON Schema(基于 lm-format-enforcer) + +### 吞吐与内存 +- **连续批处理**:高并发下的高吞吐 +- **分页 KV cache**:内存高效,支持 prefix 共享 +- **SSD 分层 KV cache**:为长上下文 agent 场景将 prefix cache 溢出到磁盘(`--ssd-cache-dir`) +- **Warm prompts**:启动时预加载热门 prefix(`--warm-prompts`),TTFT 提升 1.3-2.25 倍 +- **Prefix cache**:基于 trie,跨请求共享 + +### 多模态 +- **文本 + 图像 + 视频 + 音频** 集成于一个服务 +- 视觉模型:Gemma 3、Gemma 4、Qwen3-VL、Pixtral、Llama vision +- **聊天中的音频输入**(`audio_url` 内容块) +- **原生 TTS**:11 种声音,15+ 种语言(Kokoro、Chatterbox、VibeVoice、VoxCPM) +- **STT**:Whisper 系列,M4 Max 上 RTF 最高可达 197 倍 + +### 推理与高级功能 +- **思维链提取**:Qwen3、DeepSeek-R1(`--reasoning-parser`) +- **MoE 专家裁剪**:`--moe-top-k`,Qwen3-30B-A3B 上 +7-16% +- **投机解码**:`--mtp`,用于 Qwen3-Next +- **稀疏 prefill**:基于注意力的 `--spec-prefill`,降低 TTFT + +### 可观测性 +- **Prometheus 指标**:使用 `--metrics` 开启 `/metrics` 端点 +- **内置基准测试**:`vllm-mlx bench-serve`,支持 prompt 扫描及 CSV/JSON 输出 + +### 原生 GPU 加速 +- 仅支持 Apple Silicon(M1、M2、M3、M4),通过 MLX 使用 Metal kernel +- 统一内存,无需模型转换 + +## 性能 + +**LLM 解码(M4 Max,128 GB,greedy,单流):** + +| 模型 | Tok/s | 内存 | +|------|------:|-----:| +| Qwen3-0.6B-8bit | 417.9 | 0.7 GB | +| Llama-3.2-3B-Instruct-4bit | 205.6 | 1.8 GB | +| Qwen3-30B-A3B-4bit | 127.7 | ~18 GB | + +**音频 speech-to-text(M4 Max,RTF = real-time factor):** + +| 模型 | RTF | 适用场景 | +|------|----:|---------| +| whisper-tiny | 197x | 实时 / 低延迟 | +| whisper-large-v3-turbo | 55x | 质量与速度兼顾 | +| whisper-large-v3 | 24x | 最高精度 | + +完整结果(包括连续批处理、KV cache 量化 4-bit / 8-bit / fp16、MoE top-k 扫描)见 [docs/benchmarks/](docs/benchmarks/)。 + +## 示例 + +### Anthropic API(Claude Code、OpenCode) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --port 8000 +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### 推理模型(Qwen3、DeepSeek-R1) + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "17 乘以 23 等于多少?"}], +) +print("思考过程:", r.choices[0].message.reasoning) +print("答案:", r.choices[0].message.content) +``` + +### 多模态(图像 + 文本) + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": [ + {"type": "text", "text": "这张图里有什么?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}, + ]}], +) +``` + +### 
结构化输出(JSON Schema) + +```python +r = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "列出 3 种颜色。"}], + response_format={ + "type": "json_schema", + "json_schema": { + "schema": {"type": "object", "properties": {"colors": {"type": "array", "items": {"type": "string"}}}} + }, + }, +) +``` + +### 重排序(`/v1/rerank`) + +```bash +curl http://localhost:8000/v1/rerank -H 'Content-Type: application/json' -d '{ + "model": "default", + "query": "apple silicon 推理", + "documents": ["MLX 是苹果的框架", "M 系列芯片上的 Metal kernel", "NVIDIA 上的 CUDA"] +}' +``` + +### 嵌入(Embeddings) + +```bash +vllm-mlx serve --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +```python +emb = client.embeddings.create(model="mlx-community/all-MiniLM-L6-v2-4bit", input=["你好", "世界"]) +``` + +### 音频(TTS / STT) + +```bash +pip install vllm-mlx[audio] +brew install espeak-ng # macOS,非英语 TTS 需要 + +python examples/tts_example.py "Hello, how are you?" --play +python examples/tts_multilingual.py "你好,世界" --lang zh --play +``` + +### 内置基准测试 + +```bash +vllm-mlx bench-serve --url http://localhost:8000 --concurrency 5 --prompts prompts.txt --output results.csv +``` + +### Prometheus 指标 + +```bash +vllm-mlx serve --metrics +curl http://localhost:8000/metrics +``` + +## 安装 + +**使用 uv(推荐):** + +```bash +uv tool install vllm-mlx # 作为系统级 CLI +# 或在项目中 +uv pip install vllm-mlx +``` + +**使用 pip:** + +```bash +pip install vllm-mlx + +# 音频扩展 +pip install vllm-mlx[audio] +brew install espeak-ng +python -m spacy download en_core_web_sm +``` + +**从源码安装:** + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx +pip install -e . +``` + +更多选项见 [安装指南](docs/getting-started/installation.md)。 + +## 文档 + +- **入门**:[安装](docs/getting-started/installation.md) · [快速开始](docs/getting-started/quickstart.md) +- **服务器与 API**:[OpenAI 服务器](docs/guides/server.md) · [Anthropic Messages API](docs/guides/server.md#anthropic-messages-api) · [Python API](docs/guides/python-api.md) +- **功能**:[多模态](docs/guides/multimodal.md) · [音频](docs/guides/audio.md) · [嵌入](docs/guides/embeddings.md) · [推理模型](docs/guides/reasoning.md) · [MCP 与工具调用](docs/guides/mcp-tools.md) · [工具解析器](docs/guides/tool-calling.md) +- **性能**:[连续批处理](docs/guides/continuous-batching.md) · [Warm Prompts](docs/guides/warm-prompts.md) · [MoE Top-K](docs/guides/moe-top-k.md) +- **参考**:[CLI](docs/reference/cli.md) · [模型](docs/reference/models.md) · [配置](docs/reference/configuration.md) +- **基准测试**:[LLM](docs/benchmarks/llm.md) · [图像](docs/benchmarks/image.md) · [视频](docs/benchmarks/video.md) · [音频](docs/benchmarks/audio.md) + +## 架构 + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ vllm-mlx 服务器 │ +│ OpenAI /v1/* · Anthropic /v1/messages · /v1/rerank · /metrics │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ 连续批处理 · 分页 KV cache · Prefix cache · SSD 分层 │ +└─────────────────────────────────────────────────────────────────────────┘ + │ + ┌─────────────┬────────────┴────────────┬─────────────┐ + ▼ ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ mlx-lm │ │ mlx-vlm │ │ mlx-audio │ │mlx-embeddings │ +│ (LLMs) │ │ (视觉) │ │ (TTS + STT) │ │ (嵌入向量) │ +└───────────────┘ └───────────────┘ └───────────────┘ └───────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ MLX · Metal kernel · 统一内存 │ 
+└─────────────────────────────────────────────────────────────────────────┘ +``` + +## 贡献 + +欢迎提交 bug 修复、性能优化、文档改进以及在不同 Apple Silicon 芯片上的 benchmark。详见 [贡献指南](docs/development/contributing.md)。 + +## 许可证 + +Apache 2.0。详见 [LICENSE](LICENSE)。 + +## 引用 + +```bibtex +@software{vllm_mlx2025, + author = {Barrios, Wayner}, + title = {vllm-mlx: Apple Silicon MLX Backend for vLLM}, + year = {2025}, + url = {https://github.com/waybarrios/vllm-mlx}, + note = {Native GPU-accelerated LLM and vision-language model inference on Apple Silicon} +} +``` + +## 致谢 + +- [MLX](https://github.com/ml-explore/mlx)。Apple 的 ML 框架。 +- [mlx-lm](https://github.com/ml-explore/mlx-lm)。LLM 推理库。 +- [mlx-vlm](https://github.com/Blaizzy/mlx-vlm)。视觉语言模型。 +- [mlx-audio](https://github.com/Blaizzy/mlx-audio)。Text-to-Speech 与 Speech-to-Text。 +- [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings)。文本嵌入。 +- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX)。vllm-mlx 的社区分支。 +- [vLLM](https://github.com/vllm-project/vllm)。高吞吐 LLM 推理服务。vllm-mlx 受 vLLM 启发,借鉴了其连续批处理和分页 KV cache 的设计,通过 MLX 适配到 Apple Silicon。 + +## Star 历史 + +[![Star History Chart](https://api.star-history.com/svg?repos=waybarrios/vllm-mlx&type=Date)](https://star-history.com/#waybarrios/vllm-mlx&Date) + +--- + +**如果 vllm-mlx 对你有帮助,请给仓库点一个 star。这能帮助更多 Apple Silicon 开发者发现它。** diff --git a/docs/development/architecture.md b/docs/development/architecture.md index cf94d1560..98b6d862b 100644 --- a/docs/development/architecture.md +++ b/docs/development/architecture.md @@ -159,6 +159,39 @@ vllm_mlx/ 6. **Streaming** → SSE response chunks 7. **Caching** → KV cache storage for reuse +## Residency Lifecycle Scope + +The current residency work is scoped around safe automatic unload and reload of the +main model when residency policy decides it should be evicted, including future +memory-pressure-based eviction. It is not intended to turn `load_model()` into a +general hot-reconfiguration API for a running FastAPI server. + +### Residency Invariants + +- Any resident engine swap must invalidate cached tool parser instances. Tool + parsers can retain tokenizer-derived state, so `_tool_parser_instance` must not + survive an unload/reload boundary. +- Residency unload/reload correctness takes priority over in-process + reconfiguration. Once FastAPI lifespan startup has run, changing the main model + or residency policy should be treated as a process restart concern unless the + server explicitly adds and tests live reconfiguration support. + +### Known Limitation + +- `/v1/messages/count_tokens` currently depends on the active engine tokenizer and + may wake the main model even when lazy residency is enabled. In other words, + lazy load defers the first generation-capable request, not necessarily every + request that touches the model. +- Upstream `served_model_name` support has not yet been integrated with + residency. After rebasing, we still need an explicit follow-up to decide how + user-facing model identity should interact with resident specs, `/health`, + `/v1/status`, `/v1/models`, and cache-dir keying for local model paths. +- Eager startup hardening in this branch is currently scoped to cancellation + safety. Ordinary non-cancellation startup failures in `SimpleEngine.start()` + and `BatchedEngine.start()` still need a separate follow-up if we want to + guarantee teardown of partially prepared state before surfacing the original + startup error. 
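To make the first residency invariant above concrete, here is a minimal sketch of what an unload/reload boundary is expected to do with the cached tool parser. Only `_tool_parser_instance` is taken from the codebase; `ResidencyManager`, `load_engine`, and `load_tool_parser` are hypothetical names used purely for illustration.

```python
class ResidencyManager:
    """Illustrative sketch only; everything except _tool_parser_instance is hypothetical."""

    def __init__(self):
        self.engine = None
        self._tool_parser_instance = None

    def unload(self):
        # Evict the resident engine (residency policy or, later, memory pressure).
        self.engine = None
        # Invariant: tool parsers can retain tokenizer-derived state, so the
        # cached instance must not survive the unload/reload boundary.
        self._tool_parser_instance = None

    def reload(self, load_engine, load_tool_parser):
        # Rebuild the engine first, then derive a fresh parser from its tokenizer
        # instead of reusing anything cached before the unload.
        self.engine = load_engine()
        self._tool_parser_instance = load_tool_parser(self.engine.tokenizer)
        return self.engine
```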
+ ## Hardware Detection vllm-mlx auto-detects Apple Silicon: diff --git a/docs/es/benchmarks/README.md b/docs/es/benchmarks/README.md new file mode 100644 index 000000000..9e11fb107 --- /dev/null +++ b/docs/es/benchmarks/README.md @@ -0,0 +1,63 @@ +# Benchmarks + +Benchmarks de rendimiento para vllm-mlx en Apple Silicon. + +## Tipos de benchmark + +- [Benchmarks de LLM](llm.md) - Rendimiento de generación de texto +- [Benchmarks de imagen](image.md) - Rendimiento de comprension de imagenes +- [Benchmarks de video](video.md) - Rendimiento de comprension de video + +## Comandos rápidos + +```bash +# LLM benchmark +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# Image benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Video benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video +``` + +## Valores predeterminados de los scripts de prueba + +Los scripts de benchmark independientes tienen modelos predeterminados integrados, por lo que puedes ejecutar: + +```bash +python tests/test_continuous_batching.py +python tests/test_prefix_cache.py +``` + +Valores predeterminados: +- `tests/test_continuous_batching.py` → `mlx-community/Qwen3-8B-6bit` +- `tests/test_prefix_cache.py` → `mlx-community/Qwen3-0.6B-8bit` + +Para probar con otros modelos, usa el parámetro opcional `--model`: + +```bash +python tests/test_continuous_batching.py --model mlx-community/Qwen3-0.6B-8bit +python tests/test_prefix_cache.py --model mlx-community/Qwen3-8B-6bit +``` + +## Hardware + +Los benchmarks se recopilaron en las siguientes configuraciones de Apple Silicon: + +| Chip | Memoria | Python | +|------|---------|--------| +| Apple M4 Max | 128 GB unificada | 3.13 | +| Apple M1 Max | 64 GB unificada | 3.12 | + +Los resultados pueden variar en distintos chips de Apple Silicon. + +## Contribuir benchmarks + +Si tienes un chip de Apple Silicon diferente, comparte tus resultados: + +```bash +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json +``` + +Abre un issue con tus resultados en [GitHub Issues](https://github.com/waybarrios/vllm-mlx/issues). diff --git a/docs/es/benchmarks/audio.md b/docs/es/benchmarks/audio.md new file mode 100644 index 000000000..608b4b35f --- /dev/null +++ b/docs/es/benchmarks/audio.md @@ -0,0 +1,158 @@ +# Benchmarks de Audio + +## Benchmarks de Speech-to-Text (STT) + +### Ejecutar benchmarks de STT + +```bash +# Run with default test audio +python examples/benchmark_audio.py --stt + +# Run with your own audio file +python examples/benchmark_audio.py --stt --audio path/to/audio.wav +``` + +### Resultados (M4 Max, 128GB) + +**Audio de prueba:** 46.7 segundos de voz sintetizada + +| Model | Parameters | Load Time | Transcribe Time | RTF* | +|-------|------------|-----------|-----------------|------| +| whisper-tiny | 39M | 0.34s | 0.24s | **197x** | +| whisper-small | 244M | 0.18s | 0.47s | **98x** | +| whisper-medium | 769M | 0.35s | 1.15s | **41x** | +| whisper-large-v3 | 1.5B | 0.50s | 1.96s | **24x** | +| whisper-large-v3-turbo | 809M | 0.12s | 0.86s | **55x** | + +*RTF = Real-Time Factor (mayor es más rápido). Un RTF de 100x significa que 1 minuto de audio se transcribe en aprox. 
0.6 segundos.* + +### Resultados (M1 Max, 64GB) + +STT con Parakeet (entorno predeterminado, Whisper no disponible por incompatibilidad de dependencia con numpy): + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| parakeet-tdt-0.6b-v2 | 0.28s | 1.01s | **9.9x** | +| parakeet-tdt-0.6b-v3 | 0.30s | 0.19s | **52.7x** | + +STT con Whisper (`numpy==2.3.5` explícito + `uv run --no-sync`): + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| whisper-tiny | 4.02s | 1.05s | **9.5x** | +| whisper-small | 10.15s | 1.03s | **9.7x** | +| whisper-medium | 22.96s | 2.20s | **4.6x** | +| whisper-large-v3 | 38.34s | 0.96s | **10.5x** | +| whisper-large-v3-turbo | 21.79s | 0.70s | **14.3x** | +| parakeet-tdt-0.6b-v2 | 0.47s | 0.18s | **54.4x** | +| parakeet-tdt-0.6b-v3 | 1.13s | 0.18s | **54.6x** | + +### Recomendaciones de modelos + +| Use Case | Recommended Model | Why | +|----------|-------------------|-----| +| **Transcripcion en tiempo real** | whisper-tiny | El más rápido (197x RTF), baja latencia | +| **Uso general** | whisper-large-v3-turbo | Mejor equilibrio entre velocidad (55x) y calidad | +| **Mayor precision** | whisper-large-v3 | El más preciso, soporta más de 99 idiomas | +| **Memoria reducida** | whisper-small | Buena calidad con 244M parámetros | + +### Calidad de transcripción + +Todos los modelos transcribieron correctamente el audio de prueba. Ejemplo de salida: + +``` +Input text: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." + +Whisper-large-v3 output: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." (identical) +``` + +### Idiomas soportados + +Los modelos Whisper soportan más de 99 idiomas, entre ellos: +- Inglés, español, francés, alemán, italiano, portugués +- Chino (mandarín, cantonés), japonés, coreano +- Árabe, hindi, ruso, turco, ucraniano +- Y muchos más + +## Benchmarks de Text-to-Speech (TTS) + +### Ejecutar benchmarks de TTS + +```bash +python examples/benchmark_audio.py --tts +``` + +### Resultados (M4 Max, 128GB) + +**Prueba:** Generar audio para 3 muestras de texto (corta, media, larga) + +| Model | Load Time | Chars/sec | RTF* | +|-------|-----------|-----------|------| +| Kokoro-82M-bf16 | 0.8s | 350+ | **22x** | +| Kokoro-82M-4bit | 0.4s | 320+ | **20x** | + +*RTF = Real-Time Factor. Un RTF de 22x significa que 1 segundo de audio se genera en aprox. 
0.045 segundos.* + +### Resultados de TTS (M1 Max, 64GB) + +| Model | Load Time | Avg Chars/s | Avg RTF | +|-------|-----------|-------------|---------| +| Kokoro-82M-bf16 | 2.81s | 176.0 | **11.9x** | +| Kokoro-82M-4bit | 0.22s | 225.6 | **15.5x** | + +### Calidad de TTS + +Kokoro produce voz con sonido natural, con: +- 11 voces integradas (masculinas y femeninas) +- Soporte para 8 idiomas (inglés, español, francés, japonés, chino, italiano, portugués, hindi) +- 82M parámetros, rápido y liviano + +## Benchmarks de procesamiento de audio + +### SAM-Audio (separacion de fuentes) + +**Prueba:** Separar la bateria de una cancion de rock de 30 segundos + +| Metric | Value | +|--------|-------| +| Model | sam-audio-large-fp16 | +| Processing time | ~20s | +| Peak memory | ~27 GB | +| Output sample rate | 48000 Hz | + +## Ejecutar todos los benchmarks de audio + +```bash +# Run all benchmarks +python examples/benchmark_audio.py --all + +# Or run individually +python examples/benchmark_audio.py --stt +python examples/benchmark_audio.py --tts +``` + +## Modelos disponibles en mlx-community + +### Modelos STT +- `mlx-community/whisper-tiny-mlx` +- `mlx-community/whisper-small-mlx` +- `mlx-community/whisper-medium-mlx` +- `mlx-community/whisper-large-v3-mlx` +- `mlx-community/whisper-large-v3-turbo` +- `mlx-community/parakeet-tdt-0.6b-v2` +- `mlx-community/parakeet-tdt-0.6b-v3` + +### Modelos TTS +- `mlx-community/Kokoro-82M-bf16` (recommended) +- `mlx-community/Kokoro-82M-4bit` +- `mlx-community/chatterbox-turbo-fp16` +- `mlx-community/VibeVoice-Realtime-0.5B-4bit` + +### Procesamiento de audio +- `mlx-community/sam-audio-large-fp16` diff --git a/docs/es/benchmarks/image.md b/docs/es/benchmarks/image.md new file mode 100644 index 000000000..ff188ca2a --- /dev/null +++ b/docs/es/benchmarks/image.md @@ -0,0 +1,138 @@ +# Benchmarks de Imágenes + +## Ejecutar Benchmarks de Imágenes + +```bash +# Benchmark completo (10 resoluciones) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Benchmark rápido (4 resoluciones) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --quick +``` + +## Resultados - Qwen3-VL-8B-Instruct-4bit (M4 Max, 128GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.04s | 78 | 74.8 tok/s | +| 336x336 | 113K | 0.94s | 64 | 68.3 tok/s | +| 448x448 | 201K | 1.45s | 70 | 48.1 tok/s | +| 512x512 | 262K | 1.58s | 99 | 62.8 tok/s | +| 672x672 | 452K | 1.83s | 83 | 45.3 tok/s | +| 768x768 | 590K | 2.05s | 91 | 44.3 tok/s | +| 896x896 | 803K | 2.61s | 90 | 34.5 tok/s | +| 1024x1024 | 1.0M | 2.79s | 76 | 27.2 tok/s | +| 1280x720 | 922K | 2.97s | 96 | 32.4 tok/s | +| 1920x1080 | 2.1M | 6.30s | 89 | 14.1 tok/s | + +**Resumen:** Promedio de 45.2 tok/s en todas las resoluciones. 
Más rápido en 224x224 (74.8 tok/s), más lento en 1920x1080 (14.1 tok/s) + +## Resultados - Qwen3-VL-8B-Instruct-4bit (M1 Max, 64GB) + +Benchmark MLLM local: + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.84s | 78 | 42.5 tok/s | +| 448x448 | 201K | 2.28s | 70 | 30.7 tok/s | +| 768x768 | 590K | 4.39s | 91 | 20.7 tok/s | +| 1024x1024 | 1.0M | 6.41s | 76 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 14.92 | 315 | 21.1 | + +## Resultados - Qwen3-VL-4B-Instruct-3bit Server (M1 Max, 64GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.65s | 113 | 68.4 tok/s | +| 448x448 | 201K | 2.09s | 120 | 57.5 tok/s | +| 768x768 | 590K | 2.93s | 106 | 36.2 tok/s | +| 1024x1024 | 1.0M | 4.12s | 100 | 24.3 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 10.79 | 439 | 40.7 | + +## Resultados del Prefix Cache MLLM + +``` +====================================================================== + MLLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-VL-4B-Instruct-3bit + Test: Verify KV cache reuse for repeated image/video + prompt combinations + Expected behavior: + - Same image + same prompt → cache HIT + - Same image + different prompt → cache MISS + - Different image + same prompt → cache MISS +---------------------------------------------------------------------- + SETUP: Loading Model +---------------------------------------------------------------------- + Model loaded in 0.11s + +---------------------------------------------------------------------- + SETUP: Creating Test Images +---------------------------------------------------------------------- + Resized: 224x224, 336x336, 512x512, 768x768 + +---------------------------------------------------------------------- + TEST 1: Image Cache - Basic Hit/Miss +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 1a | First image+prompt | MISS | MISS | 0.10ms | ✓ + 1b | Same image+prompt | HIT | HIT | 0.18ms | ✓ + 1c | Different prompt | MISS | MISS | 0.01ms | ✓ + 1d | Return to original | HIT | HIT | 0.18ms | ✓ + +---------------------------------------------------------------------- + TEST 2: Different Images +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 2a | Image A first request | MISS | MISS | 0.01ms | ✓ + 2b | Image B first request | MISS | MISS | 0.01ms | ✓ + 2c | Image A cached | HIT | HIT | 0.13ms | ✓ + +---------------------------------------------------------------------- + TEST 3: Image Resolutions +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+-----------------------+----------+--------+--------+------- + 3.1a | 224x224 first | MISS | MISS | 0.01ms | ✓ + 3.1b | 224x224 cached | HIT | HIT | 0.20ms | ✓ + 3.2a | 336x336 first | MISS | MISS | 0.01ms | ✓ + 3.2b | 336x336 cached | HIT | HIT | 0.21ms | ✓ + 3.3a | 
512x512 first | MISS | MISS | 0.12ms | ✓ + 3.3b | 512x512 cached | HIT | HIT | 0.20ms | ✓ + 3.4a | 768x768 first | MISS | MISS | 0.12ms | ✓ + 3.4b | 768x768 cached | HIT | HIT | 0.24ms | ✓ +====================================================================== +``` + +## Estrategia de Clave de Cache + +- **Images**: `hash(image_content) + hash(prompt)` + +La misma imagen con el mismo prompt siempre generara un acierto en el cache. Una imagen diferente o un prompt diferente generara un fallo. + +## Consejos de Rendimiento + +- Las resoluciones menores se procesan más rápido (224x224 vs 1920x1080) +- Usa la resolucion adecuada para tu tarea +- Agrupa imagenes de tamanio similar para un rendimiento consistente + +## Referencia de Métricas + +| Metric | Description | +|--------|-------------| +| Resolution | Dimensiones de la imagen (ancho x alto) | +| Pixels | Total pixel count | +| Time | Tiempo de generación | +| Tokens | Tokens de salida generados | +| Speed | Tokens por segundo (tok/s) | diff --git a/docs/es/benchmarks/llm.md b/docs/es/benchmarks/llm.md new file mode 100644 index 000000000..1429107b9 --- /dev/null +++ b/docs/es/benchmarks/llm.md @@ -0,0 +1,271 @@ +# Benchmarks de LLM + +## Ejecutar benchmarks de LLM + +```bash +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 5 --max-tokens 256 +``` + +## Resultados (M4 Max, 128GB) + +| Model | Gen Speed | TTFT* | Memory | +|-------|-----------|-------|--------| +| Qwen3-0.6B-8bit | 402.3 tok/s | 58.6 ms | 0.68 GB | +| Llama-3.2-1B-Instruct-4bit | 463.6 tok/s | 49.2 ms | 0.69 GB | +| Qwen2.5-1.5B-Instruct-4bit | 308.5 tok/s | 86.2 ms | 0.84 GB | +| Llama-3.2-3B-Instruct-4bit | 200.1 tok/s | 81.4 ms | 1.79 GB | +| Qwen3-30B-A3B-4bit | 123.9 tok/s | 126.9 ms | 16.05 GB | +| NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit | 122.9 tok/s | 72.3 ms | 23.98 GB | + +*TTFT = Time to First Token (latencia hasta que el modelo comienza a generar) + +## Resultados (M1 Max, 64GB) + +| Model | Runs | Prompt Tok | Gen Tok | Total Time (s) | TTFT Mean (ms) | TPOT Mean (ms) | Gen Speed (tok/s) | Total Throughput (tok/s) | +|-------|------|------------|---------|-----------------|-----------------|-----------------|-------------------|--------------------------| +| Qwen3-0.6B-8bit | 5 | 56 | 1280 | 5.66 | 119.0 | 3.97 | 251.9 | 236.1 | + +## Resultados de continuous batching + +| Model | Single Request | Batch (5 req) | Speedup | +|-------|----------------|---------------|---------| +| Llama-3.2-1B-Instruct-4bit | 299.1 tok/s | 613.0 tok/s | **2.05x** | +| Llama-3.2-3B-Instruct-4bit | 137.6 tok/s | 208.1 tok/s | **1.51x** | +| Qwen3-0.6B-8bit | 328.1 tok/s | 1111.8 tok/s | **3.39x** | +| Qwen3-30B-A3B-4bit | 98.1 tok/s | 233.3 tok/s | **2.38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196.9 tok/s | 322.2 tok/s | **1.64x** | + +*Con 5 solicitudes concurrentes se observa una mejora de throughput de 1.5x a 3x.* + +### Continuous batching (M1 Max, 64GB) + +| Requests | Total Tokens | Total Time (s) | Throughput (tok/s) | Requests/sec | +|----------|--------------|-----------------|--------------------|--------------| +| 5 | 315 | 0.64 | 492.5 | 7.82 | + +## Rendimiento de streaming + +| Model | TTFT | Generation Speed | +|-------|------|------------------| +| Llama-3.2-1B-Instruct-4bit | ~4.6ms | 218.9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10.7ms | 93.6 tok/s | +| Qwen3-0.6B-8bit | ~3.0ms | 328.5 tok/s | +| Qwen3-30B-A3B-4bit | ~10.2ms | 98.4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7.1ms | 140.3 tok/s | + +### Detokenizador en streaming (M1 Max, 
64GB) + +`vllm-mlx bench-detok`: + +| Tokens | Iterations | Naive Time | Streaming Time | Speedup | +|--------|------------|------------|----------------|---------| +| 742 | 5 | 1.69ms | 0.71ms | 2.39x | + +`examples/benchmark_detokenizer.py`: + +| Sequence | Tokens | decode() | Streaming | Speedup | +|----------|--------|----------|-----------|---------| +| Short | 8 | 0.029ms | 0.028ms | 1.04x | +| Medium | 103 | 0.206ms | 0.129ms | 1.59x | +| Long | 511 | 1.040ms | 0.502ms | 2.07x | +| 1K | 1191 | 2.446ms | 1.178ms | 2.08x | +| 2K | 2381 | 4.949ms | 2.356ms | 2.10x | +| 4K | 4761 | 9.887ms | 5.398ms | 1.83x | + +Speedup promedio: 1.79x + +## Resultados del prefix cache + +### Prefix cache (M4 Max, 128GB) + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | ✓ + 1b | Same prompt | HIT | HIT | ✓ + 1c | Different prompt | MISS | MISS | ✓ + 1d | Return to prompt 1 | HIT | HIT | ✓ +====================================================================== +``` + +### Prefix cache (M1 Max, 64GB) + +| Test | Expected | Actual | Time | Status | +|------|----------|--------|------|--------| +| First request | MISS | MISS | 203.5ms | PASS | +| Same prompt | HIT | HIT | 131.6ms | PASS | +| Different prompt | MISS or PREFIX_HIT | PREFIX_HIT (5 tok) | 135.3ms | PASS | + +Estadísticas finales del cache: + +| Cache Hits | Cache Misses | Hit Rate | Tokens Saved | Cached Speedup | +|------------|--------------|----------|--------------|----------------| +| 2 | 1 | 66.7% | 20 | 1.55x | + +## Resultados del paged cache + +*Prueba: 20 solicitudes de inferencia reales en 2 rondas con un system prompt compartido de aproximadamente 286 tokens* + +``` +====================================================================== + PAGED KV CACHE - REAL INFERENCE TEST +====================================================================== + +-------------------------------------------------- +Test 1: WITHOUT Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.47s + Throughput: 681.2 tok/s + Cache hits: 0 + Tokens saved: 0 + +-------------------------------------------------- +Test 2: WITH Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.31s + Throughput: 765.8 tok/s + + Paged Cache Stats: + Blocks allocated: 25 + Shared blocks: 4 + Cache hits: 10 + Tokens saved: 2560 + +================================================== +SUMMARY +================================================== + Without paged cache: 681.2 tok/s + With paged cache: 765.8 tok/s + + Speedup: 1.12x + Cache hits: 10 (all Round 2 requests) + Tokens saved: 2,560 (~256 tokens × 10 requests) +================================================== +``` + +### Paged KV cache (M1 Max, 64GB) + +Benchmark de inferencia (20 solicitudes): + +| Mode | Time (s) | Throughput (tok/s) | +|------|----------|--------------------| +| Without paged cache | 3.43 | 291.8 | +| With paged cache | 3.42 | 292.2 | + +| Speedup | Blocks Allocated | Shared Blocks | Cache Hits | Tokens Saved | 
+|---------|------------------|---------------|------------|--------------| +| 1.00x | 45 | 4 | 10 | 2560 | + +Inferencia concurrente real (20 solicitudes): + +| Mode | Time (s) | Throughput (tok/s) | +|------|----------|--------------------| +| Without paged cache | 4.32 | 231.7 | +| With paged cache | 4.35 | 229.7 | + +| Speedup | Blocks Allocated | Shared Blocks | Cache Hits | Tokens Saved | +|---------|------------------|---------------|------------|--------------| +| 0.99x | 49 | 8 | 10 | 5120 | + +Demostración de ahorro de memoria: + +| Scenario | Memory Savings | +|----------|----------------| +| Shared system prompts | 70.8% | +| Concurrent memory efficiency | 83.5% | +| Prefix sharing branches | 38.5% | + +## Análisis del detokenizador en streaming + +*Investigación Fase 9.1: `BPEStreamingDetokenizer` de mlx-lm vs `tokenizer.decode()` naive* + +### Contexto + +El enfoque naive llama a `decode([token])` por cada token. En teoria, los detokenizadores en streaming ofrecen complejidad O(T) frente a O(T²) del decode naive. + +### Resultados del benchmark aislado + +```bash +vllm-mlx bench-detok +``` + +Al reutilizar la misma instancia del detokenizador (con `reset()` entre usos): + +| Sequence | Tokens | Naive decode() | Streaming | Speedup | +|----------|--------|----------------|-----------|---------| +| Short | 8 | 0.020ms | 0.019ms | 1.05x | +| Medium | 103 | 0.155ms | 0.097ms | 1.59x | +| Long | 511 | 0.752ms | 0.371ms | **2.03x** | +| 1K tokens | 1191 | 1.743ms | 0.833ms | **2.09x** | +| 2K tokens | 2381 | 3.493ms | 1.737ms | **2.01x** | + +### Hallazgo clave: costo de creación de instancias + +Crear una nueva instancia de `BPEStreamingDetokenizer` es **extremadamente costoso**: + +``` +100 tokenizer.detokenizer calls: 5.266s (52.7ms each!) +``` + +Esto significa que crear un nuevo detokenizador por solicitud agrega **aproximadamente 52ms de sobrecarga**, anulando cualquier beneficio. + +### Impacto en uso real + +Al integrarlo en el scheduler (un detokenizador por solicitud): + +| Metric | Naive decode() | Streaming (new instance) | +|--------|----------------|--------------------------| +| Throughput (20 req) | 681 tok/s | 275 tok/s | +| Impact | - | **-60% slower** | + +### Conclusión + +El detokenizador en streaming **no es viable actualmente** para uso por solicitud, debido al costo de creación de instancias. El enfoque naive con `decode([token])` sigue siendo más rápido en la práctica. + +**Optimizacion futura**: crear un pool de instancias de detokenizador al inicio y reutilizarlas entre solicitudes. 
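Como ilustración de esa optimización futura, el siguiente boceto muestra un pool sencillo de detokenizadores creados una sola vez al inicio y reutilizados entre solicitudes, llamando a `reset()` entre usos como en el benchmark aislado. Es solo un esquema bajo suposiciones: el nombre `DetokenizerPool` y la integración con el scheduler son hipotéticos; `tokenizer.detokenizer` y `reset()` son los que se mencionan arriba.

```python
import queue


class DetokenizerPool:
    """Boceto hipotético: reutiliza instancias de detokenizador creadas al inicio.

    Crear cada instancia cuesta ~52 ms (ver arriba), así que ese costo se paga
    una sola vez aquí y luego solo se llama a reset() entre solicitudes.
    """

    def __init__(self, tokenizer, size: int = 8):
        self._pool: queue.Queue = queue.Queue()
        for _ in range(size):
            # Cada acceso a tokenizer.detokenizer crea una instancia nueva (costoso)
            self._pool.put(tokenizer.detokenizer)

    def acquire(self):
        detok = self._pool.get()   # bloquea si todas las instancias están en uso
        detok.reset()              # limpia el estado de la solicitud anterior
        return detok

    def release(self, detok) -> None:
        self._pool.put(detok)
```

Con un pool así, el scheduler tomaría una instancia con `acquire()` al iniciar cada solicitud y la devolvería con `release()` al terminar, evitando los ~52 ms de creación por solicitud.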
+ +## Referencia de métricas + +| Metric | Description | +|--------|-------------| +| **TTFT** | Time to First Token: latencia hasta que el modelo comienza a responder (ms) | +| **TPOT** | Time Per Output Token: tiempo entre cada token generado (ms/token) | +| **Generation TPS** | Tokens de salida por segundo (tok/s) | +| **Processing TPS** | Tokens de entrada/prompt procesados por segundo (tok/s) | +| **End-to-End Latency** | Tiempo total desde la solicitud hasta la respuesta completa | +| **Total Throughput** | Tokens totales (entrada + salida) por segundo | + +## Ejecutar benchmarks + +```bash +# Basic benchmark +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# With more prompts +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --prompts 10 + +# Save results +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json + +# Continuous batching test +python tests/test_continuous_batching.py + +# Prefix cache test +python tests/test_prefix_cache.py + +# Paged cache test +python tests/test_paged_cache_real_inference.py + +# Streaming detokenizer benchmark +vllm-mlx bench-detok +vllm-mlx bench-detok mlx-community/Llama-3.2-1B-Instruct-4bit --iterations 5 +``` diff --git a/docs/es/benchmarks/video.md b/docs/es/benchmarks/video.md new file mode 100644 index 000000000..3623f1f28 --- /dev/null +++ b/docs/es/benchmarks/video.md @@ -0,0 +1,128 @@ +# Benchmarks de Video + +## Ejecutar Benchmarks de Video + +```bash +# Benchmark completo (10 configuraciones, 2-64 fotogramas) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video + +# Benchmark rápido (3 conteos de fotogramas) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --quick + +# Video personalizado +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --video-url https://example.com/video.mp4 +``` + +## Resultados - Qwen3-VL-8B-Instruct-4bit (M4 Max, 128GB) + +| Configuration | Frames | Time | Tokens | Speed | Memory | +|---------------|--------|------|--------|-------|--------| +| 2 frames @ 0.5fps | 2 | 4.48s | 256 | 57.1 tok/s | 6.4 GB | +| 4 frames @ 1fps | 4 | 4.65s | 256 | 55.0 tok/s | 6.4 GB | +| 6 frames @ 1fps | 6 | 5.15s | 197 | 38.2 tok/s | 6.6 GB | +| 8 frames @ 2fps | 8 | 6.45s | 240 | 37.2 tok/s | 6.8 GB | +| 12 frames @ 2fps | 12 | 8.73s | 256 | 29.3 tok/s | 7.1 GB | +| 16 frames @ 2fps | 16 | 10.96s | 256 | 23.4 tok/s | 7.6 GB | +| 24 frames @ 4fps | 24 | 14.95s | 226 | 15.1 tok/s | 8.4 GB | +| 32 frames @ 4fps | 32 | 20.00s | 256 | 12.8 tok/s | 9.2 GB | +| 48 frames @ 8fps | 48 | 31.11s | 246 | 7.9 tok/s | 11.1 GB | +| 64 frames @ 8fps | 64 | 59.81s | 256 | 4.3 tok/s | 12.9 GB | + +**Resumen:** Más rápido con 2 fotogramas (57.1 tok/s), más lento con 64 fotogramas (4.3 tok/s). La memoria escala de 6.4 GB a 12.9 GB. + +> **Nota:** 96 fotogramas o más provoca un timeout de GPU en la mayoría del hardware debido a los límites de memoria y cómputo. 
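La nota anterior se puede anticipar con un cálculo aproximado del número de fotogramas que extraerá una configuración dada. Es solo un boceto bajo la suposición de que fotogramas ≈ duración × fps, con tope en `max_frames`; el muestreo real puede diferir ligeramente.

```python
import math


def estimar_fotogramas(duracion_s: float, fps: float, max_frames: int = 64) -> int:
    """Boceto: estima cuántos fotogramas se extraerían de un video.

    Suposición ilustrativa: fotogramas ≈ duración × fps, limitado por max_frames.
    """
    return min(math.ceil(duracion_s * fps), max_frames)


# Un video de 30 s a 4 fps pediría 120 fotogramas; con max_frames=64 se queda
# en el máximo práctico, y valores de 96 o más implican riesgo de timeout de GPU.
print(estimar_fotogramas(30, 4.0))        # 64
print(estimar_fotogramas(30, 4.0, 128))   # 120
```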
+ +## Resultados - Qwen3-VL-8B-Instruct-4bit (M1 Max, 64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 8.84s | 256 | 29.0 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 13.05s | 256 | 19.6 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 21.60s | 256 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 43.48 | 768 | 17.7 | + +## Resultados - Qwen3-VL-4B-Instruct-3bit (M1 Max, 64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 5.09s | 150 | 29.5 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 8.36s | 150 | 17.9 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 15.21s | 150 | 9.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 28.66 | 450 | 15.7 | + +## Resultados de Caché de Video + +``` +---------------------------------------------------------------------- + TEST 4: Video Cache - fps/max_frames in Cache Key +---------------------------------------------------------------------- + Config: fps=2.0, max_frames=16 + + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 4a | Video first request | MISS | MISS | 0.03ms | ✓ + 4b | Same video+params | HIT | HIT | 0.14ms | ✓ + 4c | Different fps (4.0) | MISS | MISS | 0.01ms | ✓ + 4d | Different max_frames (32) | MISS | MISS | 0.01ms | ✓ + 4.0.5a | fps=0.5 first | MISS | MISS | 0.01ms | ✓ + 4.0.5b | fps=0.5 cached | HIT | HIT | 0.14ms | ✓ + 4.1.0a | fps=1.0 first | MISS | MISS | 0.01ms | ✓ + 4.1.0b | fps=1.0 cached | HIT | HIT | 0.14ms | ✓ + 4.2.0a | fps=2.0 first | MISS | MISS | 0.01ms | ✓ + 4.2.0b | fps=2.0 cached | HIT | HIT | 0.14ms | ✓ + 4.4.0a | fps=4.0 first | MISS | MISS | 0.01ms | ✓ + 4.4.0b | fps=4.0 cached | HIT | HIT | 0.14ms | ✓ + +---------------------------------------------------------------------- + TEST 5: Additional Videos +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 5a | Video 1 first | MISS | MISS | 0.01ms | ✓ + 5b | Video 2 first | MISS | MISS | 0.01ms | ✓ + 5c | Video 1 cached | HIT | HIT | 0.13ms | ✓ + 5d | Video 2 cached | HIT | HIT | 0.13ms | ✓ +``` + +## Estrategia de Clave de Caché + +- **Videos**: `hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +El mismo video con los mismos valores de fps, max_frames y prompt utilizará la caché. Cambiar cualquier parámetro genera un miss. 
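Como referencia, un boceto de cómo podría construirse esa clave compuesta (suposición ilustrativa; la implementación real puede combinar los hashes de otra manera):

```python
import hashlib


def clave_cache_video(video_path: str, fps: float, max_frames: int, prompt: str) -> str:
    """Boceto: clave de caché compuesta según la estrategia descrita arriba.

    Cambiar el video, fps, max_frames o el prompt produce una clave distinta
    y, por lo tanto, un cache MISS.
    """
    h = hashlib.sha256()
    for parte in (video_path, f"{fps:.3f}", str(max_frames), prompt):
        h.update(parte.encode("utf-8"))
        h.update(b"\x00")  # separador para evitar colisiones por concatenación
    return h.hexdigest()
```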
+ +## Consejos de Rendimiento + +- Menor FPS = procesamiento más rápido +- Menos fotogramas = menor uso de memoria +- 64 fotogramas es el máximo práctico +- 96 fotogramas o más provoca un timeout de GPU + +## Extracción de Fotogramas + +| FPS | 10s Video | 30s Video | 60s Video | +|-----|-----------|-----------|-----------| +| 0.5 | 5 frames | 15 frames | 30 frames | +| 1.0 | 10 frames | 30 frames | 60 frames | +| 2.0 | 20 frames | 60 frames | 120 frames* | +| 4.0 | 40 frames | 120 frames* | 240 frames* | + +*Puede alcanzar el límite de `max_frames` + +## Referencia de Métricas + +| Metric | Description | +|--------|-------------| +| Configuration | Configuración de FPS y fotogramas máximos | +| Frames | Fotogramas extraídos realmente | +| Time | Tiempo total de generación | +| Tokens | Tokens de salida generados | +| Speed | Tokens por segundo (tok/s) | +| Memory | Uso de memoria GPU | diff --git a/docs/es/getting-started/installation.md b/docs/es/getting-started/installation.md new file mode 100644 index 000000000..59f25d382 --- /dev/null +++ b/docs/es/getting-started/installation.md @@ -0,0 +1,89 @@ +# Instalacion + +## Requisitos + +- macOS en Apple Silicon (M1/M2/M3/M4) +- Python 3.10+ + +## Instalar con uv (Recomendado) + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +uv pip install -e . +``` + +## Instalar con pip + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +pip install -e . +``` + +### Opcional: Soporte para vision + +Para procesamiento de video con transformers: + +```bash +pip install -e ".[vision]" +``` + +### Opcional: Soporte de audio (STT/TTS) + +```bash +pip install mlx-audio +``` + +### Opcional: Embeddings + +```bash +pip install mlx-embeddings +``` + +## Que se instala + +- `mlx`, `mlx-lm`, `mlx-vlm` - Framework MLX y bibliotecas de modelos +- `transformers`, `tokenizers` - Bibliotecas de HuggingFace +- `opencv-python` - Procesamiento de video +- `gradio` - Interfaz de chat +- `psutil` - Monitoreo de recursos +- `mlx-audio` (opcional) - Speech-to-Text y Text-to-Speech +- `mlx-embeddings` (opcional) - Text embeddings + +## Verificar la instalacion + +```bash +# Verificar comandos CLI +vllm-mlx --help +vllm-mlx-bench --help +vllm-mlx-chat --help + +# Probar con un modelo pequeño +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 1 +``` + +## Solucion de problemas + +### MLX no encontrado + +Asegurate de estar en Apple Silicon: +```bash +uname -m # Should output "arm64" +``` + +### Fallo en la descarga del modelo + +Verifica tu conexion a internet y el acceso a HuggingFace. 
Algunos modelos requieren autenticacion: +```bash +huggingface-cli login +``` + +### Sin memoria + +Usa un modelo cuantizado más pequeno: +```bash +vllm-mlx serve mlx-community/Llama-3.2-1B-Instruct-4bit +``` diff --git a/docs/es/getting-started/quickstart.md b/docs/es/getting-started/quickstart.md new file mode 100644 index 000000000..7f6d1a4b5 --- /dev/null +++ b/docs/es/getting-started/quickstart.md @@ -0,0 +1,133 @@ +# Inicio rápido + +## Opción 1: Servidor compatible con OpenAI + +Inicia el servidor: + +```bash +# Simple mode - maximum throughput for single user +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# Continuous batching - for multiple concurrent users +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +Úsalo con el SDK de Python de OpenAI: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="mlx-community/Llama-3.2-3B-Instruct-4bit", + messages=[{"role": "user", "content": "Hello!"}], +) +print(response.choices[0].message.content) +``` + +O con curl: + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "default", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + +## Opción 2: API de Python directa + +```python +from vllm_mlx.models import MLXLanguageModel + +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) + +# Streaming +for chunk in model.stream_generate("Tell me a story"): + print(chunk.text, end="", flush=True) +``` + +## Opción 3: Interfaz de chat con Gradio + +```bash +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +Abre una interfaz web en http://localhost:7860 + +## Modelos multimodales + +Para comprensión de imágenes y video, usa un modelo VLM: + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +``` + +## Modelos de razonamiento + +Separa el proceso de pensamiento del modelo de la respuesta final: + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) +print(response.choices[0].message.content) # Final answer +``` + +## Embeddings + +Genera embeddings de texto para búsqueda semántica y RAG: + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit --embedding-model mlx-community/multilingual-e5-small-mlx +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +``` + +## Tool Calling + +Habilita la llamada a funciones con cualquier modelo compatible: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +## Próximos pasos + +- [Guía del servidor](../guides/server.md) - Configuración completa del servidor +- [API de Python](../guides/python-api.md) - Uso directo de la API +- [Guía 
multimodal](../guides/multimodal.md) - Imágenes y video +- [Guía de audio](../guides/audio.md) - Speech-to-Text y Text-to-Speech +- [Guía de embeddings](../guides/embeddings.md) - Embeddings de texto +- [Modelos de razonamiento](../guides/reasoning.md) - Modelos con pensamiento +- [Tool Calling](../guides/tool-calling.md) - Llamada a funciones +- [Modelos compatibles](../reference/models.md) - Modelos disponibles diff --git a/docs/es/guides/audio.md b/docs/es/guides/audio.md new file mode 100644 index 000000000..a44e6b1b1 --- /dev/null +++ b/docs/es/guides/audio.md @@ -0,0 +1,524 @@ +# Soporte de Audio + +vllm-mlx soporta el procesamiento de audio mediante [mlx-audio](https://github.com/Blaizzy/mlx-audio), y ofrece: + +- **STT (Speech-to-Text)**: Whisper, Parakeet +- **TTS (Text-to-Speech)**: Kokoro, Chatterbox, VibeVoice, VoxCPM +- **Procesamiento de audio**: SAM-Audio (separación de voz) + +## Instalación + +```bash +# Soporte de audio principal +pip install mlx-audio>=0.2.9 + +# Dependencias requeridas para TTS +pip install sounddevice soundfile scipy numba tiktoken misaki spacy num2words loguru phonemizer + +# Descargar el modelo de inglés de spacy +python -m spacy download en_core_web_sm + +# Para TTS en idiomas distintos al inglés (español, francés, etc.), instalar espeak-ng: +# macOS +brew install espeak-ng + +# Ubuntu/Debian +# sudo apt-get install espeak-ng +``` + +O instalar todas las dependencias de audio de una sola vez: + +```bash +pip install vllm-mlx[audio] +python -m spacy download en_core_web_sm +brew install espeak-ng # macOS, para idiomas distintos al inglés +``` + +## Inicio Rápido + +### Speech-to-Text (Transcripción) + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Transcribir un archivo de audio +with open("audio.mp3", "rb") as f: + transcript = client.audio.transcriptions.create( + model="whisper-large-v3", + file=f, + language="en" # opcional + ) +print(transcript.text) +``` + +### Text-to-Speech (Generación) + +```python +# Generar voz +audio = client.audio.speech.create( + model="kokoro", + input="Hello, how are you?", + voice="af_heart", + speed=1.0 +) + +# Guardar en archivo +with open("output.wav", "wb") as f: + f.write(audio.content) +``` + +### Separación de Voz (SAM-Audio) + +Aislar la voz del ruido de fondo, música u otros sonidos: + +```python +from vllm_mlx.audio import AudioProcessor + +# Cargar el modelo SAM-Audio +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() + +# Separar el habla del audio +result = processor.separate("meeting_with_music.mp3", description="speech") + +# Guardar la voz aislada y el fondo +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background_only.wav") +``` + +**Ejemplo de CLI:** +```bash +python examples/audio_separation_example.py meeting.mp3 --play +python examples/audio_separation_example.py song.mp3 --description music -o music.wav +``` + +### Demo de Separación de Batería + +Aislar la batería de una canción de rock usando SAM-Audio: + +| Audio | Descripción | Escuchar | +|-------|-------------|----------| +| Original | "Get Ready" de David Fesliyan (30s, libre de regalías) | [🎵 rock_get_ready.mp3](../../../examples/rock_get_ready.mp3) | +| Batería aislada | Batería extraída por SAM-Audio | [🥁 drums_isolated.wav](../../../examples/drums_isolated.wav) | +| Sin batería | Pista con la batería eliminada | [🎸 rock_no_drums.wav](../../../examples/rock_no_drums.wav) | + +```bash +# Aislar 
la batería de una canción de rock +python examples/audio_separation_example.py examples/rock_get_ready.mp3 \ + --description "drums" \ + --output drums_isolated.wav \ + --background rock_no_drums.wav +``` + +**Rendimiento:** 30 segundos de audio procesados en ~20 segundos en M4 Max. + +## Modelos Soportados + +### Modelos STT (Speech-to-Text) + +| Modelo | Alias | Idiomas | Velocidad | Calidad | +|--------|-------|---------|-----------|---------| +| `mlx-community/whisper-large-v3-mlx` | `whisper-large-v3` | 99+ | Media | Mejor | +| `mlx-community/whisper-large-v3-turbo` | `whisper-large-v3-turbo` | 99+ | Rápida | Muy buena | +| `mlx-community/whisper-medium-mlx` | `whisper-medium` | 99+ | Rápida | Buena | +| `mlx-community/whisper-small-mlx` | `whisper-small` | 99+ | Muy rápida | Aceptable | +| `mlx-community/parakeet-tdt-0.6b-v2` | `parakeet` | Inglés | La más rápida | Muy buena | +| `mlx-community/parakeet-tdt-0.6b-v3` | `parakeet-v3` | Inglés | La más rápida | Mejor | + +**Recomendación:** +- Multilingüe: `whisper-large-v3` +- Solo inglés: `parakeet` (3x más rápido) + +### Modelos TTS (Text-to-Speech) + +#### Kokoro (Rápido y ligero) - Recomendado + +| Modelo | Alias | Tamaño | Idiomas | +|--------|-------|--------|---------| +| `mlx-community/Kokoro-82M-bf16` | `kokoro` | 82M | EN, ES, FR, JA, ZH, HI, IT, PT | +| `mlx-community/Kokoro-82M-4bit` | `kokoro-4bit` | 82M | EN, ES, FR, JA, ZH, HI, IT, PT | + +**Voces (11):** +- Femenino estadounidense: `af_heart`, `af_bella`, `af_nicole`, `af_sarah`, `af_sky` +- Masculino estadounidense: `am_adam`, `am_michael` +- Femenino británico: `bf_emma`, `bf_isabella` +- Masculino británico: `bm_george`, `bm_lewis` + +**Códigos de idioma:** +| Código | Idioma | Código | Idioma | +|--------|--------|--------|--------| +| `a` / `en` | English (US) | `e` / `es` | Español | +| `b` / `en-gb` | English (UK) | `f` / `fr` | Français | +| `j` / `ja` | 日本語 | `z` / `zh` | 中文 | +| `i` / `it` | Italiano | `p` / `pt` | Português | +| `h` / `hi` | हिन्दी | | | + +#### Chatterbox (Multilingüe y expresivo) + +| Modelo | Alias | Tamaño | Idiomas | +|--------|-------|--------|---------| +| `mlx-community/chatterbox-turbo-fp16` | `chatterbox` | 134M | 15+ idiomas | +| `mlx-community/chatterbox-turbo-4bit` | `chatterbox-4bit` | 134M | 15+ idiomas | + +**Idiomas soportados:** EN, ES, FR, DE, IT, PT, RU, JA, ZH, KO, AR, HI, NL, PL, TR + +#### VibeVoice (Tiempo real) + +| Modelo | Alias | Tamaño | Caso de uso | +|--------|-------|--------|-------------| +| `mlx-community/VibeVoice-Realtime-0.5B-4bit` | `vibevoice` | 200M | Baja latencia, inglés | + +#### VoxCPM (Chino/Inglés) + +| Modelo | Alias | Tamaño | Idiomas | +|--------|-------|--------|---------| +| `mlx-community/VoxCPM1.5` | `voxcpm` | 0.9B | ZH, EN | +| `mlx-community/VoxCPM1.5-4bit` | `voxcpm-4bit` | 200M | ZH, EN | + +### Modelos de Procesamiento de Audio + +#### SAM-Audio (Separación de Voz) + +| Modelo | Tamaño | Caso de uso | +|--------|--------|-------------| +| `mlx-community/sam-audio-large-fp16` | 3B | Mejor calidad | +| `mlx-community/sam-audio-large` | 3B | Estándar | +| `mlx-community/sam-audio-small-fp16` | 0.6B | Rápido | +| `mlx-community/sam-audio-small` | 0.6B | Ligero | + +## Referencia de API + +### POST /v1/audio/transcriptions + +Transcribir audio a texto (compatible con la API OpenAI Whisper). 
+ +**Parámetros:** +- `file`: Archivo de audio (mp3, wav, m4a, webm) +- `model`: Nombre o alias del modelo +- `language`: Código de idioma (opcional, se detecta automáticamente) +- `response_format`: `json` o `text` + +**Límites:** +- Tamaño máximo de carga por defecto: 25 MiB +- Se puede ajustar con `--max-audio-upload-mb` + +**Ejemplo:** +```bash +curl http://localhost:8000/v1/audio/transcriptions \ + -F file=@audio.mp3 \ + -F model=whisper-large-v3 +``` + +### POST /v1/audio/speech + +Generar voz a partir de texto (compatible con la API OpenAI TTS). + +**Parámetros:** +- `model`: Nombre o alias del modelo +- `input`: Texto a sintetizar +- `voice`: ID de la voz +- `speed`: Velocidad del habla (0.5 a 2.0) +- `response_format`: `wav`, `mp3` + +**Límites:** +- Límite de entrada por defecto: 4096 caracteres +- Se puede ajustar con `--max-tts-input-chars` + +**Ejemplo:** +```bash +curl http://localhost:8000/v1/audio/speech \ + -d '{"model": "kokoro", "input": "Hello world", "voice": "af_heart"}' \ + -H "Content-Type: application/json" \ + --output speech.wav +``` + +### GET /v1/audio/voices + +Listar las voces disponibles para un modelo. + +**Ejemplo:** +```bash +curl http://localhost:8000/v1/audio/voices?model=kokoro +``` + +## Ejemplos de CLI + +### Transcripción en Vivo / Subtítulos + +Transcripción de voz a texto en tiempo real desde el micrófono: + +```bash +# Subtítulos con whisper-large-v3 (mejor calidad) +python examples/closed_captions.py --language es --chunk 5 + +# Modelo más rápido para menor latencia +python examples/closed_captions.py --language en --model whisper-turbo --chunk 3 + +# Transcripción básica por micrófono (grabar y luego transcribir) +python examples/mic_transcribe.py --language es + +# Transcripción en fragmentos en tiempo real +python examples/mic_realtime.py --language es --chunk 3 + +# Transcripción en vivo con detección de actividad de voz +python examples/mic_live.py --language es +``` + +**Requisitos:** +```bash +pip install sounddevice soundfile numpy +``` + +### TTS Básico + +```bash +# Ejemplo simple de TTS +python examples/tts_example.py "Hello, how are you?" --play + +# Con una voz diferente +python examples/tts_example.py "Hello!" --voice am_michael --play + +# Guardar en archivo +python examples/tts_example.py "Welcome to the demo" -o greeting.wav + +# Listar las voces disponibles +python examples/tts_example.py --list-voices +``` + +### TTS Multilingüe + +```bash +# Inglés (selecciona automáticamente el mejor modelo) +python examples/tts_multilingual.py "Hello world" --play + +# Español +python examples/tts_multilingual.py "Hola mundo" --lang es --play + +# Francés +python examples/tts_multilingual.py "Bonjour le monde" --lang fr --play + +# Japonés +python examples/tts_multilingual.py "こんにちは" --lang ja --play + +# Chino +python examples/tts_multilingual.py "你好世界" --lang zh --play + +# Usar un modelo específico +python examples/tts_multilingual.py "Hello" --model chatterbox --play + +# Listar todos los modelos +python examples/tts_multilingual.py --list-models + +# Listar todos los idiomas +python examples/tts_multilingual.py --list-languages +``` + +### Ejemplos de Asistente de Voz para Negocios + +Muestras de voz pregeneradas con **voces nativas** para casos de uso empresariales comunes: + +| Idioma | Voz | Mensaje | Escuchar | +|--------|-----|---------|----------| +| 🇺🇸 Inglés | af_heart | "Welcome to First National Bank. How may I assist you today?" 
| [▶️ assistant_bank_en.wav](../../../examples/assistant_bank_en.wav) | +| 🇪🇸 Español | ef_dora | "Gracias por llamar a servicio al cliente. Un agente le atenderá pronto." | [▶️ assistant_service_es.wav](../../../examples/assistant_service_es.wav) | +| 🇫🇷 Francés | ff_siwis | "Bienvenue. Votre appel est important pour nous." | [▶️ assistant_callcenter_fr.wav](../../../examples/assistant_callcenter_fr.wav) | +| 🇨🇳 Chino | zf_xiaobei | "欢迎致电技术支持中心。我们将竭诚为您服务。" | [▶️ assistant_support_zh.wav](../../../examples/assistant_support_zh.wav) | + +**Genera tus propias muestras con voces nativas:** +```bash +# Inglés - Asistente bancario (voz nativa: af_heart) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Welcome to First National Bank. How may I assist you today?" \ + --voice af_heart --lang_code a --file_prefix assistant_bank_en + +# Español - Atención al cliente (voz nativa: ef_dora) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Gracias por llamar a servicio al cliente. Un agente le atendera pronto." \ + --voice ef_dora --lang_code e --file_prefix assistant_service_es + +# Francés - Centro de llamadas (voz nativa: ff_siwis) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Bienvenue. Votre appel est important pour nous." \ + --voice ff_siwis --lang_code f --file_prefix assistant_callcenter_fr + +# Chino - Soporte técnico (voz nativa: zf_xiaobei) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "欢迎致电技术支持中心。我们将竭诚为您服务。" \ + --voice zf_xiaobei --lang_code z --file_prefix assistant_support_zh +``` + +### Referencia de Voces Nativas + +| Idioma | Código | Voces | +|--------|--------|-------| +| English (US) | `a` | af_heart, af_bella, af_nicole, am_adam, am_michael | +| English (UK) | `b` | bf_emma, bf_isabella, bm_george, bm_lewis | +| Español | `e` | ef_dora, em_alex, em_santa | +| Français | `f` | ff_siwis | +| 中文 | `z` | zf_xiaobei, zf_xiaoni, zf_xiaoxiao, zm_yunjian, zm_yunxi | +| 日本語 | `j` | jf_alpha, jf_gongitsune, jm_kumo | +| Italiano | `i` | if_sara, im_nicola | +| Português | `p` | pf_dora, pm_alex | +| हिन्दी | `h` | hf_alpha, hf_beta, hm_omega | + +## API de Python + +### Uso Directo (sin servidor) + +```python +from vllm_mlx.audio import STTEngine, TTSEngine, AudioProcessor + +# Speech-to-Text +stt = STTEngine("mlx-community/whisper-large-v3-mlx") +stt.load() +result = stt.transcribe("audio.mp3") +print(result.text) + +# Text-to-Speech +tts = TTSEngine("mlx-community/Kokoro-82M-bf16") +tts.load() +audio = tts.generate("Hello world", voice="af_heart") +tts.save(audio, "output.wav") + +# Separación de voz +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() +result = processor.separate("mixed_audio.mp3", description="speech") +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background.wav") +``` + +### Funciones de Conveniencia + +```python +from vllm_mlx.audio import transcribe_audio, generate_speech, separate_voice + +# Transcripción rápida +result = transcribe_audio("audio.mp3") +print(result.text) + +# TTS rápido +audio = generate_speech("Hello world", voice="af_heart") + +# Separación de voz rápida +voice, background = separate_voice("mixed.mp3") +``` + +## Audio en el Chat + +Incluir audio en mensajes de chat (se transcribe automáticamente): + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", 
"text": "Summarize this audio"}, + {"type": "audio_url", "audio_url": {"url": "file://meeting.mp3"}} + ] + }] +) +``` + +## Benchmarks + +Probado en Apple M2 Max (32GB). + +### Benchmarks de TTS (Kokoro-82M-bf16) + +| Longitud del texto | Duración del audio | Tiempo de generación | RTF | Chars/seg | +|--------------------|--------------------|----------------------|-----|-----------| +| 25 chars | 1.95s | 0.43s | 4.6x | 58.5 | +| 88 chars | 6.00s | 0.32s | 18.6x | 272.4 | +| 117 chars | 7.92s | 0.27s | 29.0x | 427.4 | + +**Resumen:** +- Tiempo de carga del modelo: ~1.0s +- RTF promedio: **17.4x** (17 veces más rápido que en tiempo real) +- Chars/seg promedio: **252.8** + +### Benchmarks de STT + +| Modelo | Tiempo de carga | Transcripción (audio de 6s) | RTF | +|--------|-----------------|------------------------------|-----| +| whisper-small | 0.25s | 0.20s | 30.2x | +| whisper-medium | 18.1s | 0.38s | 15.5x | +| whisper-large-v3 | ~30s | ~0.6s | ~10x | +| parakeet | ~0.5s | ~0.15s | ~40x | + +**Notas:** +- RTF (Real-Time Factor) indica cuántas veces más rápido que en tiempo real es el procesamiento +- La primera carga incluye la descarga del modelo desde HuggingFace +- Las cargas siguientes usan los modelos en caché + +### Recomendaciones por Caso de Uso + +| Caso de uso | Modelo recomendado | Motivo | +|-------------|-------------------|--------| +| STT en inglés rápido | `parakeet` | RTF de 40x, bajo consumo de memoria | +| STT multilingüe | `whisper-large-v3` | 99+ idiomas | +| STT de baja latencia | `whisper-small` | RTF de 30x, carga rápida | +| TTS general | `kokoro` | RTF de 17x, buena calidad | +| TTS con poca memoria | `kokoro-4bit` | Cuantizado a 4 bits | + +## Consejos de Rendimiento + +1. **Usa Parakeet para inglés**: 40x más rápido que en tiempo real +2. **Usa modelos de 4 bits** para menor uso de memoria +3. **Usa SAM-Audio small** para una separación de voz más rápida +4. **Guarda los modelos en caché**: los motores se cargan de forma diferida y quedan en caché +5. **Descarga los modelos previamente** para evitar la latencia en la primera ejecución + +## Solución de Problemas + +### mlx-audio no está instalado +``` +pip install mlx-audio>=0.2.9 +``` + +### La descarga del modelo es lenta +Los modelos se descargan desde HuggingFace en el primer uso. Usa `huggingface-cli download` para descargarlos previamente: +```bash +huggingface-cli download mlx-community/whisper-large-v3-mlx +huggingface-cli download mlx-community/Kokoro-82M-bf16 +``` + +### Sin memoria suficiente +Usa modelos más pequeños o versiones cuantizadas a 4 bits: +- `whisper-small-mlx` en lugar de `whisper-large-v3-mlx` +- `Kokoro-82M-4bit` en lugar de `Kokoro-82M-bf16` +- `sam-audio-small` en lugar de `sam-audio-large` + +### Error multilingüe de Kokoro (mlx-audio 0.2.9) + +Si obtienes `ValueError: too many values to unpack` al usar idiomas distintos al inglés (español, chino, japonés, etc.) 
con Kokoro, aplica esta corrección: + +```python +# Corrección para mlx_audio/tts/models/kokoro/pipeline.py línea 443 +# Cambia: +# ps, _ = self.g2p(chunk) +# Por: +g2p_result = self.g2p(chunk) +ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result +``` + +**Corrección en una sola línea:** +```bash +python -c " +import os +path = os.path.join(os.path.dirname(__import__('mlx_audio').__file__), 'tts/models/kokoro/pipeline.py') +with open(path, 'r') as f: content = f.read() +old = ' ps, _ = self.g2p(chunk)' +new = ''' # Fix: handle both tuple (en) and string (zh/ja/es) returns from g2p + g2p_result = self.g2p(chunk) + ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result''' +if old in content: + with open(path, 'w') as f: f.write(content.replace(old, new)) + print('Fix applied!') +" +``` + +Este error ocurre porque el g2p para inglés devuelve una tupla `(phonemes, tokens)` mientras que otros idiomas devuelven solo una cadena de texto. diff --git a/docs/es/guides/continuous-batching.md b/docs/es/guides/continuous-batching.md new file mode 100644 index 000000000..8fce998c0 --- /dev/null +++ b/docs/es/guides/continuous-batching.md @@ -0,0 +1,174 @@ +# Continuous Batching + +El continuous batching permite mayor throughput al servir múltiples usuarios concurrentes. + +## Activar Continuous Batching + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching +``` + +## Con Paged Cache + +Para compartir prefijos de forma eficiente en memoria: + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching --use-paged-cache +``` + +## Cómo Funciona + +### Modo Simple (Predeterminado) +- Una solicitud a la vez +- Máximo throughput para un solo usuario +- Sin sobrecarga por batching + +### Modo Continuous Batching +- Múltiples solicitudes procesadas en conjunto +- Mejor throughput para usuarios concurrentes +- Pequeña sobrecarga por solicitud + +### Paged Cache +- KV cache almacenado en bloques de tamaño fijo +- Los system prompts compartidos usan los mismos bloques +- Ahorro de memoria: 80% o más con 10 o más usuarios concurrentes + +## Resultados de Rendimiento + +**Resultados de Continuous Batching (M4 Max, 128GB):** + +| Modelo | Solicitud Individual | Batch (5 req) | Mejora | +|--------|----------------------|---------------|--------| +| Llama-3.2-1B-Instruct-4bit | 299.1 tok/s | 613.0 tok/s | **2.05x** | +| Llama-3.2-3B-Instruct-4bit | 137.6 tok/s | 208.1 tok/s | **1.51x** | +| Qwen3-0.6B-8bit | 328.1 tok/s | 1111.8 tok/s | **3.39x** | +| Qwen3-30B-A3B-4bit | 98.1 tok/s | 233.3 tok/s | **2.38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196.9 tok/s | 322.2 tok/s | **1.64x** | + +*El batching de 5 solicitudes concurrentes muestra una mejora de throughput de 1.5 a 3 veces.* + +## Rendimiento en Streaming + +**Rendimiento de Streaming (M4 Max, 128GB):** + +| Modelo | TTFT | Velocidad de Generación | +|--------|------|-------------------------| +| Llama-3.2-1B-Instruct-4bit | ~4.6ms | 218.9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10.7ms | 93.6 tok/s | +| Qwen3-0.6B-8bit | ~3.0ms | 328.5 tok/s | +| Qwen3-30B-A3B-4bit | ~10.2ms | 98.4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7.1ms | 140.3 tok/s | + +*TTFT = Time to First Token* + +## Configuración de Streaming + +Controla la entrega de tokens con `--stream-interval`: + +```bash +# Cada token (más fluido) +vllm-mlx serve model --continuous-batching --stream-interval 1 + +# Tokens en batch (mejor para alta latencia) +vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +| Valor | 
Comportamiento | +|-------|----------------| +| `1` | Envía cada token de inmediato | +| `2-5` | Agrupa tokens antes de enviar | +| `10+` | Máximo throughput, salida en fragmentos más grandes | + +## Gestión de Memoria + +En modelos grandes, el prefix cache puede consumir una cantidad significativa de memoria. El cache con gestión automática de memoria administra esto de forma transparente: + +```bash +# Detección automática (usa el 20% de la RAM disponible) +vllm-mlx serve model --continuous-batching + +# Límite explícito +vllm-mlx serve model --continuous-batching --cache-memory-mb 2048 + +# Porcentaje personalizado +vllm-mlx serve model --continuous-batching --cache-memory-percent 0.10 +``` + +| Opción | Descripción | +|--------|-------------| +| `--cache-memory-mb` | Establece un límite explícito en MB | +| `--cache-memory-percent` | Fracción de la RAM disponible (predeterminado: 0.20) | +| `--no-memory-aware-cache` | Usa el cache heredado basado en conteo de entradas | + +## Prefix Cache + +El prefix caching reutiliza el KV cache para prompts repetidos. + +### Cómo Funciona + +``` +User 1: System prompt (500 tokens) → Creates 8 blocks +User 2: Same system prompt → Shares 8 blocks (ref_count++) +User N: Same system prompt → Shares 8 blocks (ref_count++) + +Memory savings: 80%+ for 10+ concurrent users +``` + +### Estrategia de Clave de Cache + +- **LLM**: `hash(prompt)` +- **Images**: `hash(image_content) + hash(prompt)` +- **Videos**: `hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +### Probar el Prefix Cache + +```bash +python tests/test_prefix_cache.py +``` + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS or PREFIX_HIT (shared template tokens) +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | PASS + 1b | Same prompt | HIT | HIT | PASS + 1c | Different prompt | MISS | MISS | PASS + 1d | Return to prompt 1 | HIT | HIT | PASS +====================================================================== +``` + +## Ejecutar Benchmarks + +```bash +# Benchmark de continuous batching +python tests/test_continuous_batching.py + +# Prueba de prefix cache +python tests/test_prefix_cache.py +``` + +## Cuándo Usarlo + +| Escenario | Modo | +|-----------|------| +| Usuario individual, máxima velocidad | Simple (predeterminado) | +| Múltiples usuarios concurrentes | `--continuous-batching` | +| Modelos grandes (7B+) | `--continuous-batching --cache-memory-mb 2048` | +| Producción con prompts compartidos | `--continuous-batching --use-paged-cache` | + +## Configuración para Producción + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --port 8000 +``` diff --git a/docs/es/guides/embeddings.md b/docs/es/guides/embeddings.md new file mode 100644 index 000000000..7e662971e --- /dev/null +++ b/docs/es/guides/embeddings.md @@ -0,0 +1,150 @@ +# Embeddings + +vllm-mlx soporta embeddings de texto usando [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings), y expone un endpoint `/v1/embeddings` compatible con OpenAI. 
+ +## Instalacion + +```bash +pip install mlx-embeddings>=0.0.5 +``` + +## Inicio rápido + +### Iniciar el servidor con un modelo de embeddings + +```bash +# Precarga un modelo de embeddings especifico al inicio +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +Si no se usa `--embedding-model`, el modelo de embeddings se carga de forma diferida en la primera solicitud, pero solo desde la lista de modelos permitidos en tiempo de solicitud. + +### Generar embeddings con el SDK de OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Texto individual +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions + +# Lote de textos +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input=[ + "I love machine learning", + "Deep learning is fascinating", + "Natural language processing rocks" + ] +) +for item in response.data: + print(f"Text {item.index}: {len(item.embedding)} dimensions") +``` + +### Usando curl + +```bash +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "input": ["Hello world", "How are you?"] + }' +``` + +## Modelos soportados + +Modelos disponibles en tiempo de solicitud: + +| Model | Use Case | Size | +|-------|----------|------| +| `mlx-community/all-MiniLM-L6-v2-4bit` | Fast, compact | Small | +| `mlx-community/embeddinggemma-300m-6bit` | High quality | 300M | +| `mlx-community/bge-large-en-v1.5-4bit` | Best for English | Large | +| `mlx-community/multilingual-e5-small-mlx` | Multilingual retrieval | Small | +| `mlx-community/multilingual-e5-large-mlx` | Multilingual retrieval | Large | +| `mlx-community/bert-base-uncased-mlx` | General BERT baseline | Base | +| `mlx-community/ModernBERT-base-mlx` | ModernBERT baseline | Base | + +Otros modelos de embeddings requieren `--embedding-model` al iniciar el servidor. + +## Gestion de modelos + +### Carga diferida + +Por defecto, el modelo de embeddings se carga en la primera solicitud a `/v1/embeddings`. Es posible cambiar entre los modelos permitidos en tiempo de solicitud, y el modelo anterior se descarga automaticamente. + +### Precarga al inicio + +Usa `--embedding-model` para cargar un modelo al iniciar el servidor. Cuando se establece esta opcion, solo ese modelo puede usarse para embeddings: + +```bash +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +Solicitar un modelo diferente devolvera un error 400. + +## Referencia de la API + +### POST /v1/embeddings + +Crea embeddings para los textos de entrada proporcionados. 
+ +**Cuerpo de la solicitud:** + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `model` | string | Yes | Supported embedding model ID, or the startup-pinned model when `--embedding-model` is used | +| `input` | string or list[string] | Yes | Text(s) to embed | + +**Respuesta:** + +```json +{ + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": [0.023, -0.982, ...]}, + {"object": "embedding", "index": 1, "embedding": [0.112, -0.543, ...]} + ], + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "usage": {"prompt_tokens": 12, "total_tokens": 12} +} +``` + +## API de Python + +### Uso directo sin servidor + +```python +from vllm_mlx.embedding import EmbeddingEngine + +engine = EmbeddingEngine("mlx-community/all-MiniLM-L6-v2-4bit") +engine.load() + +vectors = engine.embed(["Hello world", "How are you?"]) +print(f"Dimensions: {len(vectors[0])}") + +tokens = engine.count_tokens(["Hello world"]) +print(f"Token count: {tokens}") +``` + +## Solucion de problemas + +### mlx-embeddings no esta instalado + +``` +pip install mlx-embeddings>=0.0.5 +``` + +### Modelo no encontrado + +Asegurate de que el nombre del modelo coincida con alguno de los IDs permitidos en tiempo de solicitud, o inicia el servidor con `--embedding-model` para fijar un modelo personalizado. Puedes descargar los modelos soportados con anticipacion: + +```bash +huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit +``` diff --git a/docs/es/guides/mcp-tools.md b/docs/es/guides/mcp-tools.md new file mode 100644 index 000000000..6577cf3b9 --- /dev/null +++ b/docs/es/guides/mcp-tools.md @@ -0,0 +1,405 @@ +# MCP y Tool Calling + +vllm-mlx soporta el Model Context Protocol (MCP) para integrar herramientas externas con LLMs. + +## Como funciona el tool calling + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Tool Calling Flow │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. User Request │ +│ ─────────────────► "List files in /tmp" │ +│ │ +│ 2. LLM Generates Tool Call │ +│ ─────────────────► tool_calls: [{ │ +│ name: "list_directory", │ +│ arguments: {path: "/tmp"} │ +│ }] │ +│ │ +│ 3. App Executes Tool via MCP │ +│ ─────────────────► MCP Server executes list_directory │ +│ Returns: ["file1.txt", "file2.txt"] │ +│ │ +│ 4. Tool Result Sent Back to LLM │ +│ ─────────────────► role: "tool", content: [...] │ +│ │ +│ 5. LLM Generates Final Response │ +│ ─────────────────► "The /tmp directory contains 2 files..." │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +## Inicio rápido + +### 1. Crear la configuración de MCP + +Crea el archivo `mcp.json`: + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### 2. Iniciar el servidor con MCP + +```bash +# Modo simple +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json + +# Continuous batching +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json --continuous-batching +``` + +### 3. Verificar el estado de MCP + +```bash +# Verificar estado de MCP +curl http://localhost:8000/v1/mcp/status + +# Listar las herramientas disponibles +curl http://localhost:8000/v1/mcp/tools +``` + +## Ejemplo de tool calling + +```python +import json +import httpx + +BASE_URL = "http://localhost:8000" + +# 1. 
Get available tools +tools_response = httpx.get(f"{BASE_URL}/v1/mcp/tools") +tools = tools_response.json()["tools"] + +# 2. Send request with tools +response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "model": "default", + "messages": [{"role": "user", "content": "List files in /tmp"}], + "tools": tools, + "max_tokens": 1024 + } +) + +result = response.json() +message = result["choices"][0]["message"] + +# 3. Check for tool calls +if message.get("tool_calls"): + tool_call = message["tool_calls"][0] + + # 4. Execute tool via MCP + exec_response = httpx.post( + f"{BASE_URL}/v1/mcp/execute", + json={ + "server": "filesystem", + "tool": tool_call["function"]["name"], + "arguments": json.loads(tool_call["function"]["arguments"]) + } + ) + tool_result = exec_response.json() + + # 5. Send result back to LLM + messages = [ + {"role": "user", "content": "List files in /tmp"}, + message, + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": json.dumps(tool_result["result"]) + } + ] + + final_response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={"model": "default", "messages": messages} + ) + print(final_response.json()["choices"][0]["message"]["content"]) +``` + +## Endpoints de MCP + +| Endpoint | Metodo | Descripcion | +|----------|--------|-------------| +| `/v1/mcp/status` | GET | Verificar el estado de MCP | +| `/v1/mcp/tools` | GET | Listar las herramientas disponibles | +| `/v1/mcp/execute` | POST | Ejecutar una herramienta | + +## Ejemplos de servidores MCP + +### Sistema de archivos + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### GitHub + +```json +{ + "mcpServers": { + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +### PostgreSQL + +```json +{ + "mcpServers": { + "postgres": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-postgres"], + "env": { + "DATABASE_URL": "postgresql://user:pass@localhost/db" + } + } + } +} +``` + +### Brave Search + +```json +{ + "mcpServers": { + "brave-search": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-brave-search"], + "env": { + "BRAVE_API_KEY": "your-key" + } + } + } +} +``` + +## Multiples servidores MCP + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +## Chat MCP interactivo + +Para probar MCP de forma interactiva: + +```bash +python examples/mcp_chat.py +``` + +## Formatos de herramientas soportados + +vllm-mlx soporta 12 tool call parsers que cubren todas las familias de modelos principales. Consulta [Tool Calling](tool-calling.md) para ver la lista completa de parsers, alias y ejemplos. + +## Seguridad + +vllm-mlx incluye medidas de seguridad para prevenir ataques de inyeccion de comandos a traves de servidores MCP. 
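+A grandes rasgos, la validación combina una lista blanca de comandos con un filtro de patrones peligrosos en los argumentos; las listas exactas se detallan en las dos secciones siguientes. Un esquema conceptual de ese tipo de chequeo (no es el código real de vllm-mlx):
+
+```python
+# Esquema conceptual de la validación de servidores MCP.
+# Las listas reales y los mensajes de error exactos se describen más abajo.
+import re
+
+ALLOWED = {"npx", "npm", "node", "uvx", "uv", "python", "python3", "pip", "pipx", "docker"}
+BLOCKED_ARG_PATTERNS = [r"[;&|]", r"\$\(", r"`", r"\.\./"]  # encadenamiento, sustitución, traversal
+
+def validate_server(name: str, command: str, args: list[str]) -> None:
+    if command not in ALLOWED and not command.startswith("mcp-server-"):
+        raise ValueError(f"MCP server '{name}': Command '{command}' is not in the allowed commands whitelist.")
+    for arg in args:
+        if any(re.search(p, arg) for p in BLOCKED_ARG_PATTERNS):
+            raise ValueError(f"MCP server '{name}': argument '{arg}' contains a blocked pattern.")
+
+validate_server("filesystem", "npx", ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"])  # pasa
+# validate_server("malicious", "bash", ["-c", "rm -rf /"])  # lanzaría ValueError
+```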
+ +### Lista blanca de comandos + +Solo se permiten comandos de confianza por defecto: + +| Categoria | Comandos permitidos | +|----------|-----------------| +| Node.js | `npx`, `npm`, `node` | +| Python | `uvx`, `uv`, `python`, `python3`, `pip`, `pipx` | +| Docker | `docker` | +| Servidores MCP | `mcp-server-*` (servidores oficiales) | + +### Patrones bloqueados + +Los siguientes patrones estan bloqueados para prevenir ataques de inyeccion: + +- Encadenamiento de comandos: `;`, `&&`, `||`, `|` +- Sustitucion de comandos: `` ` ``, `$()` +- Recorrido de rutas: `../` +- Variables de entorno peligrosas: `LD_PRELOAD`, `PATH`, `PYTHONPATH` + +### Ejemplo: ataque bloqueado + +```json +{ + "mcpServers": { + "malicious": { + "command": "bash", + "args": ["-c", "rm -rf /"] + } + } +} +``` + +Esta configuración sera rechazada: +``` +ValueError: MCP server 'malicious': Command 'bash' is not in the allowed commands whitelist. +``` + +### Modo de desarrollo (inseguro) + +Solo para desarrollo, es posible omitir la validación de seguridad: + +```json +{ + "mcpServers": { + "custom": { + "command": "my-custom-server", + "skip_security_validation": true + } + } +} +``` + +**ADVERTENCIA**: nunca uses `skip_security_validation` en produccion. + +### Lista blanca personalizada + +Para agregar comandos personalizados a la lista blanca mediante código: + +```python +from vllm_mlx.mcp import MCPCommandValidator, set_validator + +# Add custom commands +validator = MCPCommandValidator( + custom_whitelist={"my-trusted-server", "another-server"} +) +set_validator(validator) +``` + +## Sandboxing de ejecución de herramientas + +Ademas de la validación de comandos, vllm-mlx ofrece sandboxing en tiempo de ejecución para las herramientas: + +### Caracteristicas del sandbox + +| Caracteristica | Descripcion | +|---------|-------------| +| Lista blanca de herramientas | Permite ejecutar solo herramientas especificas | +| Lista negra de herramientas | Bloquea herramientas peligrosas especificas | +| Validacion de argumentos | Bloquea patrones peligrosos en los argumentos de las herramientas | +| Limite de frecuencia | Limita las llamadas a herramientas por minuto | +| Registro de auditoria | Registra todas las ejecuciones de herramientas | + +### Patrones de argumentos bloqueados + +Los argumentos de las herramientas son validados para detectar patrones peligrosos: + +- Recorrido de rutas: `../` +- Directorios del sistema: `/etc/`, `/proc/`, `/sys/` +- Acceso root: `/root/`, `~root` + +### Deteccion de herramientas de alto riesgo + +Las herramientas que coincidan con estos patrones generan advertencias de seguridad: + +- `execute`, `run_command`, `shell`, `eval`, `exec`, `system`, `subprocess` + +### Configuracion personalizada del sandbox + +```python +from vllm_mlx.mcp import ToolSandbox, set_sandbox + +# Create sandbox with custom settings +sandbox = ToolSandbox( + # Only allow specific tools (whitelist mode) + allowed_tools={"read_file", "list_directory"}, + + # Block specific tools (blacklist mode) + blocked_tools={"execute_command", "run_shell"}, + + # Rate limit: max 30 calls per minute + max_calls_per_minute=30, + + # Optional audit callback + audit_callback=lambda audit: print(f"Tool: {audit.tool_name}, Success: {audit.success}"), +) +set_sandbox(sandbox) +``` + +### Acceso a los registros de auditoria + +```python +from vllm_mlx.mcp import get_sandbox + +sandbox = get_sandbox() + +# Get recent audit entries +entries = sandbox.get_audit_log(limit=50) + +# Filter by tool name +file_ops = 
sandbox.get_audit_log(tool_filter="file") + +# Get only errors +errors = sandbox.get_audit_log(errors_only=True) + +# Clear audit log +sandbox.clear_audit_log() +``` + +### Redaccion de datos sensibles + +Los registros de auditoria redactan automaticamente los campos sensibles (password, token, secret, key, credential, auth) y truncan los valores de gran tamano. + +## Solucion de problemas + +### El servidor MCP no se conecta + +Verifica que el comando del servidor MCP sea correcto: +```bash +npx -y @modelcontextprotocol/server-filesystem /tmp +``` + +### La herramienta no se ejecuta + +Verifica que la herramienta este disponible: +```bash +curl http://localhost:8000/v1/mcp/tools | jq '.tools[].name' +``` + +### La llamada a la herramienta no se analiza + +Asegurate de usar un modelo que soporte llamadas a funciones (Qwen3, Llama-3.2-Instruct). + +### El comando no esta en la lista blanca + +Si ves "Command X is not in the allowed commands whitelist", puedes: +1. Usar un comando permitido (ver lista blanca arriba) +2. Agregar el comando a una lista blanca personalizada +3. Usar `skip_security_validation: true` (solo para desarrollo) diff --git a/docs/es/guides/moe-top-k.md b/docs/es/guides/moe-top-k.md new file mode 100644 index 000000000..cc95c2ad2 --- /dev/null +++ b/docs/es/guides/moe-top-k.md @@ -0,0 +1,125 @@ +# MoE top_k override (`--moe-top-k`) + +Reduce el numero de experts activados por token en modelos Mixture of Experts +como Qwen3-30B-A3B, intercambiando una pequeña cantidad de calidad por un +aumento significativo en el throughput de decodificacion. + +> **Estado:** flag opt-in. El comportamiento por defecto no cambia. Los numeros +> de calidad que se muestran son para Qwen3-30B-A3B-4bit en M4 Max 128 GB. +> Verifica con tu modelo antes de usarlo en cargas de produccion. + +## Que hace + +Qwen3-30B-A3B se entrena con `top_k=8`. cada token selecciona 8 de 128 +experts. En Apple Silicon con batch=1 durante la decodificacion, la multiplicacion +de matrices de experts (`SwitchGLU`) es la parte más costosa del computo por capa, +y ese costo escala de forma aproximadamente lineal con `top_k`. Reducir `top_k` +en tiempo de inferencia ha demostrado (LExI 2025, Lynx 2024) preservar la mayor +parte de la calidad entrenada mientras reduce materialmente el tiempo de +decodificacion. + +`--moe-top-k N` itera cada capa del modelo cargado y, en cada capa que tenga +`.mlp.switch_mlp` (es decir, un bloque sparse-MoE), establece `top_k = N`. Las +capas densas y los modelos densos no se modifican: el flag es un no-op para ellos. + +## Uso + +```bash +# Server +vllm-mlx serve mlx-community/Qwen3-30B-A3B-4bit \ + --continuous-batching \ + --moe-top-k 4 + +# Bench +vllm-mlx bench mlx-community/Qwen3-30B-A3B-4bit --moe-top-k 4 +``` + +El flag se rechaza si `N` es mayor que el `top_k` entrenado del modelo +(solo tiene sentido reducirlo, nunca aumentarlo). + +## Impacto medido + +### Throughput de decodificacion (M4 Max 128 GB, batch=1, greedy) + +| top_k | tok/s | vs baseline | +|---:|---:|---:| +| 8 (baseline) | 126.5 | - | +| 6 | 136.1 | +7.6% | +| 5 | 140.3 | +10.9% | +| 4 | 147.3 | +16.5% | + +### Calidad (Qwen3-30B-A3B-4bit, lm-evaluation-harness, MLX backend) + + + +| top_k | MMLU (acc) | GSM8K (exact match) | Delta vs baseline | +|---:|---:|---:|---:| +| 8 | TBD | TBD | - | +| 6 | TBD | TBD | TBD | +| 5 | TBD | TBD | TBD | +| 4 | TBD | TBD | TBD | + +MMLU: 200 muestras seleccionadas aleatoriamente, 0-shot. +GSM8K: 100 muestras seleccionadas aleatoriamente, 0-shot, exact-match estricto. 
Estos números son **indicativos**: los conjuntos de evaluación completos son
+más grandes y desplazarían la precisión absoluta, pero no el delta relativo
+entre configuraciones de forma significativa.
+
+### Paridad de salida greedy
+
+Con `top_k=4` en el checkpoint de 4 bits observamos **los primeros 16 tokens
+generados idénticos** al baseline en todos los prompts de prueba que usamos.
+Esto sugiere que top_k=4 no cambia el argmax en los pasos iniciales de
+decodificación: el modelo es internamente robusto a eliminar la mitad de sus
+experts activados.
+
+Con `top_k=3` o menor, la calidad comenzaría a degradarse de forma visible
+(no medido aquí; inferido del paper LExI). La capa de validación de configuración
+solo rechaza valores por debajo de 1; el piso recomendado para producción es `top_k=4`.
+
+## Cuándo usarlo y cuándo no
+
+Úsalo cuando:
+- Ejecutas un Qwen3 MoE (o compatible: Qwen3.5 MoE, Gemma-MoE) y el throughput
+  de decodificación con un solo usuario es tu cuello de botella.
+- Tienes una carga de trabajo donde una pequeña pérdida de calidad es aceptable
+  a cambio de una mejora visible en latencia.
+- Despliegas en hardware limitado por ancho de banda de memoria (Apple Silicon
+  serie M) donde el gather de experts domina el tiempo de decodificación por paso.
+
+No lo uses cuando:
+- Sirves modelos densos: el flag es un no-op y no aporta nada.
+- Te importa la precisión en el top-1% de suites de evaluación de leaderboard.
+- Ejecutas generaciones largas de chain-of-thought o "modo thinking", donde el
+  acantilado de calidad puede ser más pronunciado que lo que sugiere MMLU en 0-shot.
+
+## Combinación con otras optimizaciones
+
+Este flag se compone con la cuantización. En Qwen3-30B-A3B-4bit nuestra
+combinación medida es:
+
+- 4-bit + top_k=8: 126.5 tok/s (baseline)
+- 4-bit + top_k=4: 147.3 tok/s (+16.5%)
+- 3-bit + top_k=8: 138.6 tok/s (+9.6%)
+- 3-bit + top_k=6: 147.1 tok/s (+16.3%); divergencia de calidad medible
+- 3-bit + top_k=4: 157.3 tok/s (+24%); **la calidad de salida se rompe** (el modelo respondió una pregunta diferente en nuestra prueba de humo)
+
+3-bit + top_k=4 acumuló el error numérico más allá del punto donde el argmax
+es estable. Usa a lo sumo un parámetro agresivo: ya sea 4-bit + top_k=4
+o 3-bit + top_k=6. Ambos dan aproximadamente el mismo tok/s (~147) con perfiles
+de calidad muy distintos.
+
+## Internos
+
+- Helper de parcheo: `vllm_mlx.scheduler.apply_moe_top_k_override(model, k)`
+- Se aplica en `Scheduler.__init__` después de cargar el modelo.
+- Tests: `tests/test_moe_top_k.py`: cubre modelos densos, arquitecturas mixtas
+  y rutas de validación.
+
+## Referencias
+
+- LExI: Layer-Adaptive Active Experts, [arXiv 2509.02753](https://arxiv.org/html/2509.02753)
+- Not All Experts are Equal (NAEE), [ACL 2024](https://aclanthology.org/2024.acl-long.334.pdf)
+- SwiftLM (`SWIFTLM_TOP_K` env knob prior art), [github.com/SharpAI/SwiftLM](https://github.com/SharpAI/SwiftLM)
diff --git a/docs/es/guides/multimodal.md b/docs/es/guides/multimodal.md
new file mode 100644
index 000000000..33fa571f1
--- /dev/null
+++ b/docs/es/guides/multimodal.md
@@ -0,0 +1,315 @@
+# Modelos Multimodales (Imágenes y Video)
+
+vllm-mlx soporta modelos de visión y lenguaje (VLM) para el análisis de imágenes y video.
+ +## Modelos Soportados + +- Qwen3-VL (recomendado) +- Qwen2-VL +- Gemma 3 +- LLaVA +- Idefics +- PaliGemma +- Pixtral +- Molmo +- DeepSeek-VL + +## Iniciar un Servidor Multimodal + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +Los modelos que contienen "VL", "Vision" o "mllm" en el nombre se detectan automáticamente como multimodales. + +## Análisis de Imágenes + +### Mediante el SDK de OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Imagen desde URL +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +print(response.choices[0].message.content) +``` + +### Imágenes en Base64 + +```python +import base64 + +def encode_image(path): + with open(path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + +base64_image = encode_image("photo.jpg") +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + ] + }] +) +``` + +### Mediante curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + "max_tokens": 256 + }' +``` + +## Análisis de Video + +### Mediante el SDK de OpenAI + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What happens in this video?"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + max_tokens=512 +) +``` + +### Parámetros de Video + +Controla la extracción de fotogramas mediante parámetros adicionales en el cuerpo de la solicitud: + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "video.mp4"}} + ] + }], + extra_body={ + "video_fps": 2.0, + "video_max_frames": 32 + } +) +``` + +### Mediante curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + "video_fps": 2.0, + "video_max_frames": 16 + }' +``` + +## Formatos Soportados + +### Imágenes + +| Formato | Ejemplo | +|--------|---------| +| URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| Archivo local | `{"type": "image_url", "image_url": {"url": "/path/to/image.jpg"}}` | +| Base64 | `{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}` | + +### Videos + +| Formato | Ejemplo | +|--------|---------| +| URL | `{"type": "video_url", "video_url": {"url": "https://..."}}` | +| Archivo local | `{"type": "video", "video": "/path/to/video.mp4"}` | +| Base64 | `{"type": "video_url", "video_url": {"url": 
"data:video/mp4;base64,..."}}` | + +## API de Python + +```python +from vllm_mlx.models import MLXMultimodalLM + +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Imagen +description = mllm.describe_image("photo.jpg") + +# Video +description = mllm.describe_video("video.mp4", fps=2.0) + +# Prompt personalizado +output = mllm.generate( + prompt="Compare these images", + images=["img1.jpg", "img2.jpg"] +) +``` + +## Consejos de Rendimiento + +### Imágenes +- Las resoluciones menores se procesan más rápido (224x224 vs 1920x1080) +- Usa la resolución adecuada para tu tarea + +### Videos +- Menor FPS = procesamiento más rápido +- Menos fotogramas = menor uso de memoria +- 64 fotogramas es el máximo práctico (96 o más causa timeout en la GPU) + +## Benchmarks + +Probado en Apple M4 Max con 128 GB de memoria unificada. + +### Qwen3-VL-4B-Instruct-3bit + +| Resolución | Tiempo | Tokens | Velocidad | Memoria | +|------------|------|--------|-------|--------| +| 224x224 | 0.87s | 124 | 143 tok/s | 2.6 GB | +| 448x448 | 1.01s | 107 | 106 tok/s | 3.1 GB | +| 768x768 | 1.42s | 127 | 89 tok/s | 3.4 GB | +| 1024x1024 | 1.85s | 116 | 63 tok/s | 3.6 GB | + +### Qwen3-VL-8B-Instruct-4bit + +| Resolución | Tiempo | Tokens | Velocidad | Memoria | +|------------|------|--------|-------|--------| +| 224x224 | 1.08s | 78 | 73 tok/s | 5.6 GB | +| 448x448 | 1.41s | 70 | 50 tok/s | 6.1 GB | +| 768x768 | 2.06s | 91 | 44 tok/s | 6.5 GB | +| 1024x1024 | 3.02s | 76 | 25 tok/s | 7.6 GB | + +### Gemma 3 4B 4bit + +| Resolución | Tiempo | Tokens | Velocidad | Memoria | +|------------|------|--------|-------|--------| +| 224x224 | 0.95s | 30 | 32 tok/s | 5.2 GB | +| 448x448 | 0.99s | 34 | 34 tok/s | 5.2 GB | +| 768x768 | 0.99s | 32 | 32 tok/s | 5.2 GB | +| 1024x1024 | 0.95s | 28 | 29 tok/s | 5.2 GB | + +### Ejecutar Benchmarks + +```bash +# Benchmark rápido +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --quick + +# Benchmark completo con más resoluciones +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Benchmark de video +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --video +``` + +## MLLM Cache + +vllm-mlx incluye un sistema de prefix cache para modelos multimodales que puede acelerar significativamente las solicitudes repetidas con las mismas imágenes. + +### Cómo Funciona + +Cuando envías una imagen al modelo, el encoder de visión la procesa y genera embeddings. Este procesamiento toma entre 1 y 2 segundos. El MLLM cache almacena esos embeddings junto con el estado del KV cache, de modo que las solicitudes posteriores con la misma imagen omiten el encoder de visión por completo. + +El cache utiliza hashing basado en contenido (similar a LMCache) para identificar imágenes idénticas sin importar cómo se proporcionen (URL, base64 o ruta de archivo). 
+ +### Habilitar el Cache + +```bash +# Habilitar con configuración predeterminada (512 MB máximo) +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --enable-mllm-cache + +# Con límite de memoria personalizado +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit \ + --enable-mllm-cache \ + --mllm-cache-max-mb 1024 +``` + +### API de Python + +```python +from vllm_mlx.mllm_cache import MLLMPrefixCacheManager + +# Crear el gestor de cache +cache = MLLMPrefixCacheManager(max_memory_mb=512) + +# Almacenar embeddings y KV cache tras el procesamiento +cache.store( + images=["photo.jpg"], + prompt="Describe this image", + vision_embeddings=embeddings, + kv_cache=kv_state, + num_tokens=128 +) + +# Recuperar del cache en solicitudes posteriores +entry, match_len = cache.fetch(images=["photo.jpg"], prompt="Describe this image") +if entry: + # Usar embeddings en cache, omitir el encoder de visión + embeddings = entry.vision_embeddings + kv_state = entry.kv_cache +``` + +### Estadísticas del Cache + +```python +stats = cache.get_stats() +print(f"Hit rate: {stats.hit_rate:.1%}") +print(f"Memory used: {stats.memory_used_mb:.1f} MB") +print(f"Tokens saved: {stats.tokens_saved}") +``` + +### Gestión de Memoria + +El cache utiliza evicción LRU (Least Recently Used) cuando se alcanza el límite de memoria. Cada entrada registra: + +- Tamaño de los embeddings de visión +- Tamaño del KV cache por capa +- Frecuencia de acceso para el ordenamiento LRU + +Cuando hay presión de memoria, las entradas con acceso menos reciente se evictan primero. + +## Gradio Chat UI + +Para chat multimodal interactivo: + +```bash +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit +``` + +Soporta arrastrar y soltar imágenes y videos. diff --git a/docs/es/guides/python-api.md b/docs/es/guides/python-api.md new file mode 100644 index 000000000..f1fa04f7e --- /dev/null +++ b/docs/es/guides/python-api.md @@ -0,0 +1,182 @@ +# Python API + +API de Python directa para acceso programático a vllm-mlx. 
+ +## Modelos de lenguaje + +### Uso básico + +```python +from vllm_mlx.models import MLXLanguageModel + +# Load model +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) +``` + +### Generación con streaming + +```python +for chunk in model.stream_generate("Tell me a story about a robot"): + print(chunk.text, end="", flush=True) +``` + +### Interfaz de chat + +```python +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, who are you?"} +] +response = model.chat(messages) +print(response.text) +``` + +### Parámetros de generación + +```python +output = model.generate( + prompt="Write a poem", + max_tokens=256, + temperature=0.7, + top_p=0.9, + stop=["END", "\n\n"] +) +``` + +| Parámetro | Descripción | Valor por defecto | +|-----------|-------------|---------| +| `max_tokens` | Cantidad máxima de tokens a generar | 256 | +| `temperature` | Temperatura de muestreo (0-2) | 0.7 | +| `top_p` | Nucleus sampling | 0.9 | +| `stop` | Secuencias de parada | None | + +## Modelos de visión y lenguaje (VLM) + +### Uso básico + +```python +from vllm_mlx.models import MLXMultimodalLM + +# Load model +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Describe an image +description = mllm.describe_image("photo.jpg") +print(description) +``` + +### Preguntas sobre imágenes + +```python +answer = mllm.answer_about_image("photo.jpg", "What color is the car?") +print(answer) +``` + +### Varias imágenes + +```python +output = mllm.generate( + prompt="Compare these two images", + images=["image1.jpg", "image2.jpg"] +) +print(output.text) +``` + +### Comprensión de video + +```python +# From local file +output = mllm.generate( + prompt="What is happening in this video?", + videos=["video.mp4"], + video_fps=2.0, + video_max_frames=16 +) +print(output.text) + +# From URL +output = mllm.generate( + prompt="Describe this video", + videos=["https://example.com/video.mp4"], + video_fps=2.0 +) + +# Convenience method +description = mllm.describe_video("video.mp4", fps=2.0) +``` + +### Parámetros de video + +| Parámetro | Descripción | Valor por defecto | +|-----------|-------------|---------| +| `video_fps` | Fotogramas por segundo a extraer | 2.0 | +| `video_max_frames` | Cantidad máxima de fotogramas a procesar | 32 | + +## Engine API + +Para casos de uso avanzados, se puede usar el engine directamente: + +### Engine simple + +```python +from vllm_mlx.engine import SimpleEngine + +engine = SimpleEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) +print(output.text) + +await engine.stop() +``` + +### Engine con batching + +```python +from vllm_mlx.engine import BatchedEngine + +engine = BatchedEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +# Multiple concurrent requests +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) + +await engine.stop() +``` + +## Formato de salida + +Todos los métodos de generación retornan un objeto `GenerationOutput`: + +```python +output = model.generate("Hello") + +print(output.text) # Generated text +print(output.prompt_tokens) # Input token count +print(output.completion_tokens) # Output token count +print(output.finish_reason) # "stop" or "length" +``` + +## Manejo de errores + +```python +from 
vllm_mlx.models import MLXLanguageModel + +try: + model = MLXLanguageModel("invalid-model") + model.load() +except Exception as e: + print(f"Failed to load model: {e}") +``` diff --git a/docs/es/guides/reasoning.md b/docs/es/guides/reasoning.md new file mode 100644 index 000000000..c853773c8 --- /dev/null +++ b/docs/es/guides/reasoning.md @@ -0,0 +1,267 @@ +# Modelos de reasoning + +vllm-mlx admite modelos de reasoning que muestran su proceso de thinking antes de dar una respuesta. Modelos como Qwen3 y DeepSeek-R1 envuelven su reasoning en etiquetas `...`, y vllm-mlx puede analizar estas etiquetas para separar el reasoning de la respuesta final. + +## Por que usar el reasoning parser? + +Cuando un modelo de reasoning genera salida, normalmente luce asi: + +``` + +Let me analyze this step by step. +First, I need to consider the constraints. +The answer should be a prime number less than 10. +Checking: 2, 3, 5, 7 are all prime and less than 10. + +The prime numbers less than 10 are: 2, 3, 5, 7. +``` + +Sin el reasoning parser, obtienes la salida cruda con las etiquetas incluidas. Con el reasoning parser habilitado, el proceso de thinking y la respuesta final se separan en campos distintos dentro de la respuesta de la API. + +## Primeros pasos + +### Iniciar el servidor con el reasoning parser + +```bash +# For Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# For DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +### Formato de respuesta de la API + +Cuando el reasoning parser esta habilitado, la respuesta de la API incluye un campo `reasoning`: + +**Respuesta sin streaming:** + +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "The prime numbers less than 10 are: 2, 3, 5, 7.", + "reasoning": "Let me analyze this step by step.\nFirst, I need to consider the constraints.\nThe answer should be a prime number less than 10.\nChecking: 2, 3, 5, 7 are all prime and less than 10." + } + }] +} +``` + +**Respuesta con streaming:** + +Los fragmentos se envian por separado para el reasoning y el contenido. Durante la fase de reasoning, los fragmentos tienen `reasoning` con valor. Cuando el modelo pasa a la respuesta final, los fragmentos tienen `content` con valor: + +```json +{"delta": {"reasoning": "Let me analyze"}} +{"delta": {"reasoning": " this step by step."}} +{"delta": {"reasoning": "\nFirst, I need to"}} +... 
+{"delta": {"content": "The prime"}} +{"delta": {"content": " numbers less than 10"}} +{"delta": {"content": " are: 2, 3, 5, 7."}} +``` + +## Uso con el SDK de OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Non-streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What are the prime numbers less than 10?"}] +) + +message = response.choices[0].message +print("Reasoning:", message.reasoning) # The thinking process +print("Answer:", message.content) # The final answer +``` + +### Streaming con reasoning + +```python +reasoning_text = "" +content_text = "" + +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Solve: 2 + 2 = ?"}], + stream=True +) + +for chunk in stream: + delta = chunk.choices[0].delta + if hasattr(delta, 'reasoning') and delta.reasoning: + reasoning_text += delta.reasoning + print(f"[Thinking] {delta.reasoning}", end="") + if delta.content: + content_text += delta.content + print(delta.content, end="") + +print(f"\n\nFinal reasoning: {reasoning_text}") +print(f"Final answer: {content_text}") +``` + +## Parsers disponibles + +### Parser de Qwen3 (`qwen3`) + +Para modelos Qwen3 que usan etiquetas explicitas `` y ``. + +- Requiere **ambas** etiquetas, la de apertura y la de cierre +- Si faltan las etiquetas, la salida se trata como contenido regular +- Recomendado para: Qwen3-0.6B, Qwen3-4B, Qwen3-8B y modelos similares + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +### Parser de DeepSeek-R1 (`deepseek_r1`) + +Para modelos DeepSeek-R1 que pueden omitir la etiqueta de apertura ``. + +- Mas permisivo que el parser de Qwen3 +- Maneja casos donde `` es implicita +- El contenido antes de `` se trata como reasoning incluso sin `` + +```bash +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +## Como funciona + +El reasoning parser usa deteccion basada en texto para identificar etiquetas de thinking en la salida del modelo. Durante el streaming, rastrea la posicion actual en la salida para enrutar correctamente cada token a `reasoning` o a `content`. + +``` +Model Output: Step 1: analyze...The answer is 42. + ├─────────────────────┤├─────────────────────┤ +Parsed: │ reasoning ││ content │ + └─────────────────────┘└─────────────────────┘ +``` + +El parsing no tiene estado y usa el texto acumulado para determinar el contexto, lo que lo hace robusto para escenarios de streaming donde los tokens pueden llegar en fragmentos arbitrarios. + +## Consejos para mejores resultados + +### Prompting + +Los modelos de reasoning funcionan mejor cuando se les anima a pensar paso a paso: + +```python +messages = [ + {"role": "system", "content": "Think through problems step by step before answering."}, + {"role": "user", "content": "What is 17 × 23?"} +] +``` + +### Manejo del reasoning ausente + +Algunos prompts pueden no activar el reasoning. 
En esos casos, `reasoning` sera `None` y toda la salida va a `content`: + +```python +message = response.choices[0].message +if message.reasoning: + print(f"Model's thought process: {message.reasoning}") +print(f"Answer: {message.content}") +``` + +### Temperatura y reasoning + +Las temperaturas más bajas tienden a producir patrones de reasoning más consistentes: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Explain quantum entanglement"}], + temperature=0.3 # More focused reasoning +) +``` + +## Compatibilidad con versiones anteriores + +Cuando no se especifica `--reasoning-parser`, el servidor se comporta como antes: +- Las etiquetas de thinking se incluyen en el campo `content` +- No se agrega el campo `reasoning` a las respuestas + +Esto garantiza que las aplicaciones existentes sigan funcionando sin cambios. + +## Ejemplo: solucionador de problemas matematicos + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +def solve_math(problem: str) -> dict: + """Solve a math problem and return reasoning + answer.""" + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a math tutor. Show your work."}, + {"role": "user", "content": problem} + ], + temperature=0.2 + ) + + message = response.choices[0].message + return { + "problem": problem, + "work": message.reasoning, + "answer": message.content + } + +result = solve_math("If a train travels 120 km in 2 hours, what is its average speed?") +print(f"Problem: {result['problem']}") +print(f"\nWork shown:\n{result['work']}") +print(f"\nFinal answer: {result['answer']}") +``` + +## Ejemplos con curl + +### Sin streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}] + }' +``` + +### Con streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}], + "stream": true + }' +``` + +## Solucion de problemas + +### No aparece el campo reasoning en la respuesta + +- Asegurate de haber iniciado el servidor con `--reasoning-parser` +- Verifica que el modelo realmente use etiquetas de thinking (no todos los prompts activan el reasoning) + +### El reasoning aparece en content + +- Es posible que el modelo no este usando el formato de etiquetas esperado +- Prueba un parser diferente (`qwen3` vs `deepseek_r1`) + +### Reasoning truncado + +- Aumenta `--max-tokens` si el modelo esta alcanzando el limite de tokens a mitad del thinking + +## Relacionado + +- [Modelos admitidos](../reference/models.md) - Modelos que admiten reasoning +- [Configuracion del servidor](server.md) - Todas las opciones del servidor +- [Referencia de CLI](../reference/cli.md) - Opciones de línea de comandos diff --git a/docs/es/guides/server.md b/docs/es/guides/server.md new file mode 100644 index 000000000..71e1e6f77 --- /dev/null +++ b/docs/es/guides/server.md @@ -0,0 +1,781 @@ +# Servidor compatible con OpenAI + +vllm-mlx provee un servidor FastAPI con compatibilidad completa con la API de OpenAI. + +Por defecto, el servidor escucha solo en `127.0.0.1`. Usa `--host 0.0.0.0` solo cuando quieras exponerlo fuera de la máquina local de forma intencional. 
+ +## Iniciar el servidor + +### Modo simple (por defecto) + +Máximo rendimiento para un solo usuario: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 +``` + +### Modo continuous batching + +Para múltiples usuarios concurrentes: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +### Con paged cache + +Caché eficiente en memoria para producción: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching --use-paged-cache +``` + +## Opciones del servidor + +| Opción | Descripción | Valor por defecto | +|--------|-------------|---------| +| `--port` | Puerto del servidor | 8000 | +| `--host` | Host del servidor | 127.0.0.1 | +| `--api-key` | Clave de API para autenticación | None | +| `--rate-limit` | Solicitudes por minuto por cliente (0 = desactivado) | 0 | +| `--timeout` | Tiempo límite de solicitud en segundos | 300 | +| `--enable-metrics` | Expone métricas de Prometheus en `/metrics` | False | +| `--continuous-batching` | Activa batching para múltiples usuarios | False | +| `--use-paged-cache` | Activa paged KV cache | False | +| `--cache-memory-mb` | Límite de memoria de caché en MB | Auto | +| `--cache-memory-percent` | Fracción de RAM para caché | 0.20 | +| `--max-tokens` | Máximo de tokens por defecto | 32768 | +| `--max-request-tokens` | Máximo de `max_tokens` aceptado de clientes de la API | 32768 | +| `--default-temperature` | Temperatura por defecto cuando no se especifica | None | +| `--default-top-p` | top_p por defecto cuando no se especifica | None | +| `--stream-interval` | Tokens por fragmento de streaming | 1 | +| `--mcp-config` | Ruta al archivo de configuración de MCP | None | +| `--reasoning-parser` | Parser para modelos de reasoning (`qwen3`, `deepseek_r1`) | None | +| `--embedding-model` | Pre-carga un modelo de embeddings al iniciar | None | +| `--enable-auto-tool-choice` | Activa tool calling automático | False | +| `--tool-call-parser` | Parser de tool calls (ver [Tool Calling](tool-calling.md)) | None | + +## Endpoints de la API + +### Chat completions + +```bash +POST /v1/chat/completions +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Sin streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100 +) + +# Con streaming +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Tell me a story"}], + stream=True +) +for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") +``` + +### Completions + +```bash +POST /v1/completions +``` + +```python +response = client.completions.create( + model="default", + prompt="The capital of France is", + max_tokens=50 +) +``` + +### Models + +```bash +GET /v1/models +``` + +Retorna los modelos disponibles. + +### Embeddings + +```bash +POST /v1/embeddings +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions +``` + +Consulta la [Guía de Embeddings](embeddings.md) para más detalles. + +### Health check + +```bash +GET /health +``` + +Retorna el estado del servidor. + +### Métricas + +```bash +GET /metrics +``` + +Endpoint de scrape de Prometheus con métricas del servidor, caché, scheduler y solicitudes. 
+El endpoint está desactivado por defecto y se habilita con `--enable-metrics`. + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-metrics +``` + +`/metrics` no requiere autenticación de forma intencional. Exponlo solo en una red de confianza o detrás de un proxy inverso o firewall que limite quién puede consultarlo. + +### API de mensajes de Anthropic + +```bash +POST /v1/messages +``` + +Endpoint compatible con Anthropic que permite que herramientas como Claude Code y OpenCode se conecten directamente a vllm-mlx. Internamente traduce las solicitudes de Anthropic al formato de OpenAI, ejecuta la inferencia a través del motor y convierte la respuesta de vuelta al formato de Anthropic. + +Capacidades: +- Respuestas sin streaming y con streaming (SSE) +- Mensajes de sistema (cadena de texto simple o lista de bloques de contenido) +- Conversaciones multi-turno con mensajes de usuario y asistente +- Tool calling con bloques de contenido `tool_use` / `tool_result` +- Conteo de tokens para seguimiento de presupuesto +- Contenido multimodal (imágenes mediante bloques `source`) +- Detección de desconexión del cliente (retorna HTTP 499) +- Filtrado automático de tokens especiales en la salida en streaming + +#### Sin streaming + +```python +from anthropic import Anthropic + +client = Anthropic(base_url="http://localhost:8000", api_key="not-needed") + +response = client.messages.create( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Hello!"}] +) +print(response.content[0].text) +# Response includes: response.id, response.model, response.stop_reason, +# response.usage.input_tokens, response.usage.output_tokens +``` + +#### Streaming + +El streaming sigue el protocolo de eventos SSE de Anthropic. Los eventos se emiten en este orden: +`message_start` -> `content_block_start` -> `content_block_delta` (repetido) -> `content_block_stop` -> `message_delta` -> `message_stop` + +```python +with client.messages.stream( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Tell me a story"}] +) as stream: + for text in stream.text_stream: + print(text, end="") +``` + +#### Mensajes de sistema + +Los mensajes de sistema pueden ser una cadena de texto simple o una lista de bloques de contenido: + +```python +# Plain string +response = client.messages.create( + model="default", + max_tokens=256, + system="You are a helpful coding assistant.", + messages=[{"role": "user", "content": "Write a hello world in Python"}] +) + +# List of content blocks +response = client.messages.create( + model="default", + max_tokens=256, + system=[ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Be concise in your answers."}, + ], + messages=[{"role": "user", "content": "What is 2+2?"}] +) +``` + +#### Tool calling + +Define las herramientas con `name`, `description` e `input_schema`. El modelo retorna bloques de contenido `tool_use` cuando desea llamar a una herramienta. Envía los resultados de vuelta como bloques `tool_result`. 
+ +```python +# Step 1: Send request with tools +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) + +# Step 2: Check if model wants to use tools +for block in response.content: + if block.type == "tool_use": + print(f"Tool: {block.name}, Input: {block.input}, ID: {block.id}") + # response.stop_reason will be "tool_use" + +# Step 3: Send tool result back +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[ + {"role": "user", "content": "What's the weather in Paris?"}, + {"role": "assistant", "content": response.content}, + {"role": "user", "content": [ + { + "type": "tool_result", + "tool_use_id": block.id, + "content": "Sunny, 22C" + } + ]} + ], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) +print(response.content[0].text) # "The weather in Paris is sunny, 22C." +``` + +Modos de selección de herramientas: + +| `tool_choice` | Comportamiento | +|---------------|----------| +| `{"type": "auto"}` | El modelo decide si llamar herramientas (por defecto) | +| `{"type": "any"}` | El modelo debe llamar al menos una herramienta | +| `{"type": "tool", "name": "get_weather"}` | El modelo debe llamar la herramienta especificada | +| `{"type": "none"}` | El modelo no llamará ninguna herramienta | + +#### Conversaciones multi-turno + +```python +messages = [ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What's my name?"}, +] + +response = client.messages.create( + model="default", + max_tokens=100, + messages=messages +) +``` + +#### Conteo de tokens + +```bash +POST /v1/messages/count_tokens +``` + +Cuenta los tokens de entrada para una solicitud de Anthropic usando el tokenizador del modelo. Útil para el seguimiento de presupuesto antes de enviar una solicitud. Cuenta tokens de mensajes de sistema, mensajes de conversación, entradas de tool_use, contenido de tool_result y definiciones de herramientas (name, description, input_schema). 
+ +```python +import requests + +resp = requests.post("http://localhost:8000/v1/messages/count_tokens", json={ + "model": "default", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "system": "You are helpful.", + "tools": [{ + "name": "search", + "description": "Search the web", + "input_schema": {"type": "object", "properties": {"q": {"type": "string"}}} + }] +}) +print(resp.json()) # {"input_tokens": 42} +``` + +#### Ejemplos con curl + +Sin streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +Con streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "stream": true, + "messages": [{"role": "user", "content": "Tell me a joke"}] + }' +``` + +Conteo de tokens: + +```bash +curl http://localhost:8000/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}] + }' +# {"input_tokens": 12} +``` + +#### Campos de la solicitud + +| Campo | Tipo | Requerido | Valor por defecto | Descripción | +|-------|------|----------|---------|-------------| +| `model` | string | sí | - | Nombre del modelo (usa `"default"` para el modelo cargado) | +| `messages` | list | sí | - | Mensajes de conversación con `role` y `content` | +| `max_tokens` | int | sí | - | Número máximo de tokens a generar | +| `system` | string o list | no | null | Prompt de sistema (cadena o lista de bloques `{"type": "text", "text": "..."}`) | +| `stream` | bool | no | false | Activa el streaming SSE | +| `temperature` | float | no | 0.7 | Temperatura de muestreo (0.0 = determinista, 1.0 = creativo) | +| `top_p` | float | no | 0.9 | Umbral de nucleus sampling | +| `top_k` | int | no | null | Top-k sampling | +| `stop_sequences` | list | no | null | Secuencias que detienen la generación | +| `tools` | list | no | null | Definiciones de herramientas con `name`, `description`, `input_schema` | +| `tool_choice` | dict | no | null | Modo de selección de herramientas (`auto`, `any`, `tool`, `none`) | +| `metadata` | dict | no | null | Metadatos arbitrarios (se pasan sin ser usados por el servidor) | + +#### Formato de respuesta + +Respuesta sin streaming: + +```json +{ + "id": "msg_abc123...", + "type": "message", + "role": "assistant", + "model": "default", + "content": [ + {"type": "text", "text": "Hello! 
How can I help?"} + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 12, + "output_tokens": 8 + } +} +``` + +Cuando se llaman herramientas, `content` incluye bloques `tool_use` y `stop_reason` es `"tool_use"`: + +```json +{ + "content": [ + {"type": "text", "text": "Let me check the weather."}, + { + "type": "tool_use", + "id": "call_abc123", + "name": "get_weather", + "input": {"city": "Paris"} + } + ], + "stop_reason": "tool_use" +} +``` + +Razones de parada: + +| `stop_reason` | Significado | +|---------------|---------| +| `end_turn` | El modelo terminó de forma natural | +| `tool_use` | El modelo quiere llamar una herramienta | +| `max_tokens` | Se alcanzó el límite de `max_tokens` | + +#### Uso con Claude Code + +Apunta Claude Code directamente a tu servidor vllm-mlx: + +```bash +# Start the server +vllm-mlx serve mlx-community/Qwen3-Coder-Next-235B-A22B-4bit \ + --continuous-batching \ + --enable-auto-tool-choice \ + --tool-call-parser hermes + +# In another terminal, configure Claude Code +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### Estado del servidor + +```bash +GET /v1/status +``` + +Endpoint de monitoreo en tiempo real que retorna estadísticas generales del servidor y detalles por solicitud. Útil para depurar el rendimiento, rastrear la eficiencia de la caché y monitorear la memoria GPU Metal. + +```bash +curl -s http://localhost:8000/v1/status | python -m json.tool +``` + +Respuesta de ejemplo: + +```json +{ + "status": "running", + "model": "mlx-community/Qwen3-8B-4bit", + "uptime_s": 342.5, + "steps_executed": 1247, + "num_running": 1, + "num_waiting": 0, + "total_requests_processed": 15, + "total_prompt_tokens": 28450, + "total_completion_tokens": 3200, + "metal": { + "active_memory_gb": 5.2, + "peak_memory_gb": 8.1, + "cache_memory_gb": 2.3 + }, + "cache": { + "type": "memory_aware_cache", + "entries": 5, + "hit_rate": 0.87, + "memory_mb": 2350 + }, + "requests": [ + { + "request_id": "req_abc123", + "phase": "generation", + "tokens_per_second": 45.2, + "ttft_s": 0.8, + "progress": 0.35, + "cache_hit_type": "prefix", + "cached_tokens": 1200, + "generated_tokens": 85, + "max_tokens": 256 + } + ] +} +``` + +Campos de la respuesta: + +| Campo | Descripción | +|-------|-------------| +| `status` | Estado del servidor: `running`, `stopped` o `not_loaded` | +| `model` | Nombre del modelo cargado | +| `uptime_s` | Segundos desde que el servidor inició | +| `steps_executed` | Total de pasos de inferencia ejecutados | +| `num_running` | Número de solicitudes generando tokens actualmente | +| `num_waiting` | Número de solicitudes en cola para prefill | +| `total_requests_processed` | Total de solicitudes completadas desde el inicio | +| `total_prompt_tokens` | Total de tokens de prompt procesados desde el inicio | +| `total_completion_tokens` | Total de tokens de completion generados desde el inicio | +| `metal.active_memory_gb` | Memoria GPU Metal en uso actualmente (GB) | +| `metal.peak_memory_gb` | Uso pico de memoria GPU Metal (GB) | +| `metal.cache_memory_gb` | Uso de memoria de caché Metal (GB) | +| `cache` | Estadísticas de caché (tipo, entradas, tasa de aciertos, uso de memoria) | +| `requests` | Lista de solicitudes activas con detalles por solicitud | + +Campos por solicitud en `requests`: + +| Campo | Descripción | +|-------|-------------| +| `request_id` | Identificador único de la solicitud | +| `phase` | Fase actual: `queued`, `prefill` o `generation` | +| 
`tokens_per_second` | Rendimiento de generación para esta solicitud | +| `ttft_s` | Tiempo hasta el primer token (segundos) | +| `progress` | Porcentaje de completado (0.0 a 1.0) | +| `cache_hit_type` | Tipo de coincidencia en caché: `exact`, `prefix`, `supersequence`, `lcp` o `miss` | +| `cached_tokens` | Número de tokens servidos desde caché | +| `generated_tokens` | Tokens generados hasta ahora | +| `max_tokens` | Máximo de tokens solicitados | + +## Tool Calling + +Activa tool calling compatible con OpenAI con `--enable-auto-tool-choice`: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +Usa la opción `--tool-call-parser` para seleccionar el parser adecuado para tu modelo: + +| Parser | Modelos | +|--------|--------| +| `auto` | Detección automática (prueba todos los parsers) | +| `mistral` | Mistral, Devstral | +| `qwen` | Qwen, Qwen3 | +| `llama` | Llama 3.x, 4.x | +| `hermes` | Hermes, NousResearch | +| `deepseek` | DeepSeek V3, R1 | +| `kimi` | Kimi K2, Moonshot | +| `granite` | IBM Granite 3.x, 4.x | +| `nemotron` | NVIDIA Nemotron | +| `xlam` | Salesforce xLAM | +| `functionary` | MeetKai Functionary | +| `glm47` | GLM-4.7, GLM-4.7-Flash | + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + } + }] +) + +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"{tc.function.name}: {tc.function.arguments}") +``` + +Consulta la [Guía de Tool Calling](tool-calling.md) para la documentación completa. + +## Modelos de reasoning + +Para modelos que muestran su proceso de pensamiento (Qwen3, DeepSeek-R1), usa `--reasoning-parser` para separar el reasoning de la respuesta final: + +```bash +# Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +La respuesta de la API incluye un campo `reasoning` con el proceso de pensamiento del modelo: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) + +print(response.choices[0].message.reasoning) # Step-by-step thinking +print(response.choices[0].message.content) # Final answer +``` + +En streaming, los fragmentos de reasoning llegan primero, seguidos de los fragmentos de contenido: + +```python +for chunk in stream: + delta = chunk.choices[0].delta + if delta.reasoning: + print(f"[Thinking] {delta.reasoning}") + if delta.content: + print(delta.content, end="") +``` + +Consulta la [Guía de Modelos de Reasoning](reasoning.md) para todos los detalles. 
+ +## Salida estructurada (modo JSON) + +Obliga al modelo a retornar JSON válido usando `response_format`: + +### Modo JSON Object + +Retorna cualquier JSON válido: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={"type": "json_object"} +) +# Output: {"colors": ["red", "blue", "green"]} +``` + +### Modo JSON Schema + +Retorna JSON que coincide con un esquema específico: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "colors", + "schema": { + "type": "object", + "properties": { + "colors": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["colors"] + } + } + } +) +# Output validated against schema +data = json.loads(response.choices[0].message.content) +assert "colors" in data +``` + +### Ejemplo con curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "List 3 colors"}], + "response_format": {"type": "json_object"} + }' +``` + +## Ejemplos con curl + +### Chat + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 100 + }' +``` + +### Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": true + }' +``` + +## Configuración de streaming + +Controla el comportamiento del streaming con `--stream-interval`: + +| Valor | Comportamiento | +|-------|----------| +| `1` (por defecto) | Envía cada token inmediatamente | +| `2-5` | Agrupa tokens antes de enviar | +| `10+` | Máximo rendimiento, salida en fragmentos más grandes | + +```bash +# Smooth streaming +vllm-mlx serve model --continuous-batching --stream-interval 1 + +# Batched streaming (better for high-latency networks) +vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +## Integración con Open WebUI + +```bash +# 1. Start vllm-mlx server +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# 2. Start Open WebUI +docker run -d -p 3000:8080 \ + -e OPENAI_API_BASE_URL=http://host.docker.internal:8000/v1 \ + -e OPENAI_API_KEY=not-needed \ + --name open-webui \ + ghcr.io/open-webui/open-webui:main + +# 3. 
Open http://localhost:3000 +``` + +## Despliegue en producción + +### Con systemd + +Crea `/etc/systemd/system/vllm-mlx.service`: + +```ini +[Unit] +Description=vLLM-MLX Server +After=network.target + +[Service] +Type=simple +ExecStart=/usr/local/bin/vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching --use-paged-cache --port 8000 +Restart=always + +[Install] +WantedBy=multi-user.target +``` + +```bash +sudo systemctl enable vllm-mlx +sudo systemctl start vllm-mlx +``` + +### Configuración recomendada + +Para producción con 50 o más usuarios concurrentes: + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --timeout 120 \ + --port 8000 +``` diff --git a/docs/es/guides/tool-calling.md b/docs/es/guides/tool-calling.md new file mode 100644 index 000000000..970909467 --- /dev/null +++ b/docs/es/guides/tool-calling.md @@ -0,0 +1,244 @@ +# Tool Calling + +vllm-mlx soporta tool calling compatible con OpenAI (function calling) con análisis automático para muchas familias de modelos populares. + +## Inicio rápido + +Activa el tool calling agregando la bandera `--enable-auto-tool-choice` al iniciar el servidor: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +Luego usa herramientas con la API estándar de OpenAI: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"] + } + } + }] +) + +# Check for tool calls +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"Function: {tc.function.name}") + print(f"Arguments: {tc.function.arguments}") +``` + +## Parsers disponibles + +Usa `--tool-call-parser` para seleccionar un tool parser según tu familia de modelos: + +| Parser | Alias | Modelos | Formato | +|--------|-------|---------|---------| +| `auto` | | Cualquier modelo | Detecta el formato automáticamente (prueba todos los parsers) | +| `mistral` | | Mistral, Devstral | Arreglo JSON con `[TOOL_CALLS]` | +| `qwen` | `qwen3` | Qwen, Qwen3 | XML `` o `[Calling tool:]` | +| `llama` | `llama3`, `llama4` | Llama 3.x, 4.x | Etiquetas `` | +| `hermes` | `nous` | Hermes, NousResearch | JSON `` dentro de XML | +| `deepseek` | `deepseek_v3`, `deepseek_r1` | DeepSeek V3, R1 | Delimitadores Unicode | +| `kimi` | `kimi_k2`, `moonshot` | Kimi K2, Moonshot | Tokens `<\|tool_call_begin\|>` | +| `granite` | `granite3` | IBM Granite 3.x, 4.x | `<\|tool_call\|>` o `` | +| `nemotron` | `nemotron3` | NVIDIA Nemotron | `` | +| `xlam` | | Salesforce xLAM | JSON con arreglo `tool_calls` | +| `functionary` | `meetkai` | MeetKai Functionary | Múltiples bloques de función | +| `glm47` | `glm4` | GLM-4.7, GLM-4.7-Flash | `` con XML ``/`` | + +## Ejemplos por modelo + +### Mistral / Devstral + +```bash +# Devstral Small (optimizado para código y tool use) +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral + +# Mistral Instruct +vllm-mlx serve 
mlx-community/Mistral-7B-Instruct-v0.3-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +### Qwen + +```bash +# Qwen3 +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser qwen +``` + +### Llama + +```bash +# Llama 3.2 +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser llama +``` + +### DeepSeek + +```bash +# DeepSeek V3 +vllm-mlx serve mlx-community/DeepSeek-V3-0324-4bit \ + --enable-auto-tool-choice --tool-call-parser deepseek +``` + +### IBM Granite + +```bash +# Granite 4.0 +vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \ + --enable-auto-tool-choice --tool-call-parser granite +``` + +### NVIDIA Nemotron + +```bash +# Nemotron 3 Nano +vllm-mlx serve mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit \ + --enable-auto-tool-choice --tool-call-parser nemotron +``` + +### GLM-4.7 + +```bash +# GLM-4.7 Flash +vllm-mlx serve lmstudio-community/GLM-4.7-Flash-MLX-8bit \ + --enable-auto-tool-choice --tool-call-parser glm47 +``` + +### Kimi K2 + +```bash +# Kimi K2 +vllm-mlx serve mlx-community/Kimi-K2-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser kimi +``` + +### Salesforce xLAM + +```bash +# xLAM +vllm-mlx serve mlx-community/xLAM-2-fc-r-4bit \ + --enable-auto-tool-choice --tool-call-parser xlam +``` + +## Parser automático + +Si no sabes qué parser usar, el parser `auto` intenta detectar el formato de forma automática: + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser auto +``` + +El parser automático prueba los formatos en este orden: +1. Mistral (`[TOOL_CALLS]`) +2. Qwen con corchetes (`[Calling tool:]`) +3. Nemotron (``) +4. XML de Qwen/Hermes (`{...}`) +5. Llama (`{...}`) +6. JSON sin formato + +## Streaming de tool calls + +Los tool calls funcionan con streaming. 
La información del tool call se envía cuando el modelo termina de generarla: + +```python +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's 25 * 17?"}], + tools=[{ + "type": "function", + "function": { + "name": "calculator", + "description": "Calculate math expressions", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + } + }], + stream=True +) + +for chunk in stream: + if chunk.choices[0].delta.tool_calls: + for tc in chunk.choices[0].delta.tool_calls: + print(f"Tool call: {tc.function.name}({tc.function.arguments})") +``` + +## Manejo de resultados de herramientas + +Después de recibir un tool call, ejecuta la función y devuelve el resultado: + +```python +import json + +# Primera solicitud: el modelo decide llamar a una herramienta +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Tokyo?"}], + tools=[weather_tool] +) + +# Obtener el tool call +tool_call = response.choices[0].message.tool_calls[0] +tool_call_id = tool_call.id +function_name = tool_call.function.name +arguments = json.loads(tool_call.function.arguments) + +# Ejecutar la función (implementación propia) +result = get_weather(**arguments) # {"temperature": 22, "condition": "sunny"} + +# Enviar el resultado de vuelta al modelo +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "user", "content": "What's the weather in Tokyo?"}, + {"role": "assistant", "tool_calls": [tool_call]}, + {"role": "tool", "tool_call_id": tool_call_id, "content": json.dumps(result)} + ], + tools=[weather_tool] +) + +print(response.choices[0].message.content) +# "The weather in Tokyo is sunny with a temperature of 22C." +``` + +## Manejo de etiquetas de razonamiento + +Los modelos que producen etiquetas de razonamiento `...` (como DeepSeek-R1, Qwen3, GLM-4.7) se manejan de forma automática. El parser elimina el contenido de reasoning antes de extraer los tool calls, por lo que las etiquetas de razonamiento nunca interfieren con el análisis de tool calls. + +Esto funciona incluso cuando `` fue inyectado en el prompt (etiquetas implícitas con solo un cierre ``). + +## Referencia de CLI + +| Opción | Descripción | +|--------|-------------| +| `--enable-auto-tool-choice` | Activa el tool calling automático | +| `--tool-call-parser` | Selecciona el parser (ver tabla anterior) | + +Consulta la [Referencia de CLI](../reference/cli.md) para todas las opciones. diff --git a/docs/es/guides/warm-prompts.md b/docs/es/guides/warm-prompts.md new file mode 100644 index 000000000..c1d3b7710 --- /dev/null +++ b/docs/es/guides/warm-prompts.md @@ -0,0 +1,192 @@ +# Warm Prompts + +Pre-pobla el prefix cache al iniciar el servidor para que la **primera** solicitud +que envie un agent encuentre un cache caliente en lugar de pagar el prefill completo +de su system prompt de varios kilobytes. + +## Cuándo usar esto + +Las cargas de trabajo de agents, proxies hacia asistentes de código o razonamiento, +servidores MCP, orquestadores multi-agent, siempre envian el mismo system prompt. +Hoy, la primera solicitud desde un servidor frio paga el prefill completo de ese +sistema. En un modelo de miles de millones de parámetros eso equivale a varios +segundos de TTFT, justo cuando un usuario espera que su nuevo agent responda por +primera vez. 
+ +Si ya conoces los system prompts de tus agents al momento del despliegue, escríbelos +en un archivo JSON y apunta `--warm-prompts` hacia él. El servidor ejecuta un chat +completion de `max_tokens=1` para cada uno al inicio, el estado del KV cache queda +en el prefix cache, y la primera solicitud real coincide via strict-prefix. + +Requiere `--continuous-batching` (el prefix cache vive ahí). + +## Ejemplo rápido + +```bash +# Write the agents you care about once +cat > ~/.config/vllm-mlx/agents.json <<'JSON' +[ + [{"role": "system", "content": "You are a code assistant..."}] +] +JSON + +# Point the server at it +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json +``` + +Al iniciar verás: + +``` +[lifespan] Warm-up done (strict-prefix): 1 completed, 0 skipped, + 1431 prompt tokens in 0.2s +``` + +La primera solicitud real que comparte el system prompt calentado accede al cache +con `tokens_saved` cercano a la longitud del prompt de calentamiento. + +## Formato del archivo + +Una lista JSON de nivel superior. Cada entrada es a su vez una lista de mensajes +de chat, con la misma forma que `messages` en `/v1/chat/completions`. + +```json +[ + [ + {"role": "system", "content": "You are a code assistant..."} + ], + [ + {"role": "system", "content": "You are a senior code reviewer..."} + ], + [ + {"role": "system", "content": "You are a planner..."}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello, what are we planning?"} + ] +] +``` + +Los system prompts de un solo mensaje son el caso más común. Los historiales +multi-turno son compatibles para escenarios en los que quieres calentar un inicio +de conversación específico (ejemplos few-shot, una persona de asistente fija). + +## Dimensionamiento + +Los warm prompts se procesan **de forma concurrente** via `asyncio.gather`, por lo +que N entradas lanzan N prefills concurrentes al inicio. Cada prefill asigna KV +cache según la longitud de su prompt. + +**Recomendado: 1 a 3 entradas.** Eso cubre los caminos calientes de despliegues +típicos de agents (una persona por entrada). Un archivo warm-prompts muy grande en +un modelo con poca memoria puede agotar el espacio disponible en el arranque. + +Si necesitas calentar decenas de personas, abre un issue con tu carga de trabajo y +podemos agregar un limite `--warm-prompts-concurrency=N`. + +## Benchmarks + +**Configuración.** M4 Max, 128 GB de memoria unificada. Dos servidores separados por +medición (frio vs caliente), arranque frio aislado. Conjunto de prompts `long` +(aprox. 2.5k tokens de usuario) antepuesto con un system prompt de aprox. 1.7k +tokens para coincidir con el warm prompt. `max_tokens=128`. bench-serve con +`--skip-preflight-token-count` para que el preflight de count_prompt_tokens no +contamine el cache. 
+ +| Model | conc | cold TTFT | warm TTFT | Speedup | +|-------|-----:|----------:|----------:|--------:| +| Qwen3-0.6B-8bit | 1 | 563 ms | 419 ms | 1.34x | +| Qwen3-0.6B-8bit | 4 | 1 723 ms | 1 282 ms | 1.34x | +| Qwen3-0.6B-8bit | 8 | 3 708 ms | 2 661 ms | 1.39x | +| Llama-3.2-3B-Instruct-4bit | 1 | 1 754 ms | 1 060 ms | 1.65x | +| Llama-3.2-3B-Instruct-4bit | 4 | 5 926 ms | 3 945 ms | 1.50x | +| Llama-3.2-3B-Instruct-4bit | 8 | 15 161 ms | 9 820 ms | 1.54x | +| Qwen3-4B-4bit | 1 | 4 937 ms | 2 191 ms | 2.25x | +| Qwen3-4B-4bit | 4 | 12 535 ms | 9 623 ms | 1.30x | +| Qwen3-4B-4bit | 8 | 38 148 ms | 23 878 ms | 1.60x | +| Qwen3.6-35B-A3B-4bit (MoE/hybrid) | 1 | 2 400 ms | 1 603 ms | 1.50x | +| Qwen3.6-35B-A3B-4bit | 4 | 8 735 ms | 6 054 ms | 1.44x | +| Qwen3.6-35B-A3B-4bit | 8 | 22 419 ms | 14 409 ms | 1.56x | + +Las 12 configuraciones mejoran. Los ahorros de TTFT son mayores cuando la relación +prompt/total es más alta (conc=1, system prompt largo) y siguen siendo significativos +bajo carga concurrente. + +**Generation tok/s** es neutral (dentro de +-5%) para los modelos densos. +Qwen3.6-35B-A3B (MoE) muestra una caida en decode del 20 al 35% con conc >= 4, +que parece ser una interacción del enrutamiento MoE con el scheduling en batch. Los +ahorros de TTFT siguen dominando la latencia extremo a extremo en cargas de trabajo +de agents, pero toma nota de esto si tu flujo es fuertemente decode-bound a alta +concurrencia. + +## Cómo funciona + +El calentamiento naive, renderizar la plantilla de chat con un mensaje de usuario de +relleno y cachear los tokens, no funciona para modelos híbridos SSM+attention +(Qwen3.5-MoE, Qwen3.6-MoE). Sus capas de cache incluyen estado SSM que no puede +recortarse, por lo que `memory_cache.py` deshabilita la coincidencia LCP. El +contenido de usuario de relleno diverge del contenido real del usuario y una entrada +cacheada a nivel de tokens ya no es un strict-prefix de ninguna solicitud real. + +El calentador aquí renderiza la plantilla de chat **dos veces** con dos contenidos de +usuario distintos (`"__PROBE_A__"` y `"__PROBE_B__"`), encuentra la posición de +carácter donde las dos cadenas divergen y trunca el primer renderizado en ese limite. +Esa cadena truncada, todo lo que precede al punto donde se inserta el contenido del +usuario, es lo que se envía al motor. + +Dado que el flujo de solicitudes reales del motor también renderiza la plantilla con +`tokenize=False` y luego deja que el tokenizador codifique el resultado, los tokens +del calentamiento tienen garantia de ser un strict-prefix de cualquier solicitud real +con un sistema coincidente e historial de chat vacío. Las coincidencias strict-prefix +funcionan en todo tipo de capas de cache, incluidos los flujos híbridos donde el LCP +está deshabilitado. + +## Administración + +### Limpiar el prefix cache en memoria + +```bash +curl -X DELETE http://localhost:8000/v1/cache/prefix +``` + +Si el servidor se inició con `--warm-prompts`, el calentamiento se vuelve a ejecutar +en segundo plano después de la limpieza. La respuesta se devuelve de inmediato sin +esperar a que termine el re-calentamiento. + +Respuesta: + +```json +{"status": "cleared", "rewarm_scheduled": true} +``` + +### Inspeccionar el estado del cache + +```bash +curl http://localhost:8000/v1/status | jq '.cache' +``` + +Tras el arranque con warm-prompts verás `entry_count > 0` antes de la primera +solicitud del usuario. 
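
Como referencia, el truco de doble renderizado descrito arriba en «Cómo funciona» puede esbozarse así. Es solo un boceto ilustrativo con nombres hipotéticos (no es el código real del proyecto): renderiza la plantilla con dos contenidos de usuario distintos y trunca en el primer carácter donde divergen.

```python
def warm_prefix(tokenizer, system_messages):
    """Boceto: obtiene el prefijo de plantilla común a cualquier solicitud
    real con el mismo system prompt e historial de chat vacío."""
    def render(probe):
        # Igual que el flujo real: renderizar la plantilla sin tokenizar.
        return tokenizer.apply_chat_template(
            system_messages + [{"role": "user", "content": probe}],
            tokenize=False,
        )

    a, b = render("__PROBE_A__"), render("__PROBE_B__")
    # Primer índice de carácter donde los dos renderizados divergen:
    # todo lo anterior precede al punto donde se inserta el contenido del usuario.
    i = next(
        (k for k, (ca, cb) in enumerate(zip(a, b)) if ca != cb),
        min(len(a), len(b)),
    )
    return a[:i]
```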
+ +## Benchmark de tu propia configuración + +Para medir el impacto en tu modelo y tus prompts, usa `bench-serve`: + +```bash +# Cold: no warm-prompts +vllm-mlx serve MODEL --continuous-batching & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag cold \ + --output cold.csv --format csv + +# Warm: same server config + --warm-prompts +vllm-mlx serve MODEL --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag warm \ + --output warm.csv --format csv +``` + +`--skip-preflight-token-count` se habilita automáticamente cuando se usa +`--system-prompt-file`, por lo que el preflight de `count_prompt_tokens` no +contamina el cache. Compara `cold.csv` y `warm.csv` para tu carga de trabajo. diff --git a/docs/es/index.md b/docs/es/index.md new file mode 100644 index 000000000..919f961f5 --- /dev/null +++ b/docs/es/index.md @@ -0,0 +1,66 @@ +# Documentación de vLLM-MLX + +**Backend MLX para Apple Silicon en vLLM** - Aceleración GPU para texto, imagen, video y audio en Mac + +## ¿Qué es vLLM-MLX? + +vllm-mlx incorpora aceleración GPU nativa de Apple Silicon a vLLM mediante la integración de: + +- **[MLX](https://github.com/ml-explore/mlx)**: El framework de ML de Apple con memoria unificada y kernels Metal +- **[mlx-lm](https://github.com/ml-explore/mlx-lm)**: Inferencia LLM optimizada con KV cache y cuantización +- **[mlx-vlm](https://github.com/Blaizzy/mlx-vlm)**: Modelos visión-lenguaje para inferencia multimodal +- **[mlx-audio](https://github.com/Blaizzy/mlx-audio)**: TTS y STT con voces nativas +- **[mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings)**: Embeddings de texto para búsqueda semántica y RAG + +## Características principales + +- **Multimodal** - Texto, imagen, video y audio en una sola plataforma +- **Aceleración GPU nativa** en Apple Silicon (M1, M2, M3, M4) +- **Voces TTS nativas** - Español, francés, chino, japonés y 5 idiomas más +- **Compatible con la API de OpenAI** - reemplazo directo del cliente de OpenAI +- **Embeddings** - Endpoint `/v1/embeddings` compatible con OpenAI +- **MCP Tool Calling** - integración de herramientas externas mediante el Model Context Protocol +- **Paged KV Cache** - almacenamiento en caché eficiente en memoria con prefix sharing +- **Continuous Batching** - alto rendimiento para múltiples usuarios concurrentes + +## Enlaces rápidos + +### Primeros pasos +- [Instalación](getting-started/installation.md) +- [Inicio rápido](getting-started/quickstart.md) + +### Guías de usuario +- [Servidor compatible con OpenAI](guides/server.md) +- [API de Python](guides/python-api.md) +- [Multimodal (imágenes y video)](guides/multimodal.md) +- [Audio (STT/TTS)](guides/audio.md) +- [Embeddings](guides/embeddings.md) +- [Modelos de reasoning](guides/reasoning.md) +- [Tool Calling](guides/tool-calling.md) +- [MCP y Tool Calling](guides/mcp-tools.md) +- [Continuous Batching](guides/continuous-batching.md) + +### Referencia +- [Comandos CLI](reference/cli.md) +- [Modelos compatibles](reference/models.md) +- [Configuración](reference/configuration.md) + +### Benchmarks +- [Benchmarks LLM](benchmarks/llm.md) +- [Benchmarks de imagen](benchmarks/image.md) +- [Benchmarks de video](benchmarks/video.md) +- [Benchmarks de audio](benchmarks/audio.md) + +### Desarrollo +- [Arquitectura](../development/architecture.md) +- [Contribuir](../development/contributing.md) + +## Requisitos + +- macOS en Apple Silicon 
(M1/M2/M3/M4) +- Python 3.10+ +- Se recomiendan 8 GB de RAM o más + +## Licencia + +Apache 2.0 - Consulta [LICENSE](../../LICENSE) para más detalles. diff --git a/docs/es/reference/cli.md b/docs/es/reference/cli.md new file mode 100644 index 000000000..0eeeea141 --- /dev/null +++ b/docs/es/reference/cli.md @@ -0,0 +1,210 @@ +# Referencia de CLI + +## Resumen de comandos + +| Comando | Descripcion | +|---------|-------------| +| `vllm-mlx serve` | Inicia el servidor compatible con OpenAI | +| `vllm-mlx-bench` | Ejecuta benchmarks de rendimiento | +| `vllm-mlx-chat` | Inicia la interfaz de chat con Gradio | + +## `vllm-mlx serve` + +Inicia el servidor de API compatible con OpenAI. + +### Uso + +```bash +vllm-mlx serve [options] +``` + +### Opciones + +| Opcion | Descripcion | Por defecto | +|--------|-------------|-------------| +| `--served-model-name` | Nombre personalizado del modelo expuesto a traves de la API de OpenAI. Si no se especifica, se usa la ruta del modelo como nombre. | None | +| `--port` | Puerto del servidor | 8000 | +| `--host` | Host del servidor | 127.0.0.1 | +| `--api-key` | Clave de API para autenticacion | None | +| `--rate-limit` | Solicitudes por minuto por cliente (0 = desactivado) | 0 | +| `--timeout` | Tiempo limite de solicitud en segundos | 300 | +| `--enable-metrics` | Expone métricas de Prometheus en `/metrics` | False | +| `--continuous-batching` | Activa continuous batching para multiples usuarios | False | +| `--cache-memory-mb` | Limite de memoria para cache en MB | Auto | +| `--cache-memory-percent` | Fraccion de RAM para cache | 0.20 | +| `--no-memory-aware-cache` | Usa cache legacy basado en conteo de entradas | False | +| `--use-paged-cache` | Activa el KV cache paginado | False | +| `--max-tokens` | Maximo de tokens por defecto | 32768 | +| `--max-request-tokens` | Maximo de `max_tokens` aceptado desde clientes de la API | 32768 | +| `--stream-interval` | Tokens por fragmento de streaming | 1 | +| `--mcp-config` | Ruta al archivo de configuración de MCP | None | +| `--paged-cache-block-size` | Tokens por bloque de cache | 64 | +| `--max-cache-blocks` | Maximos bloques de cache | 1000 | +| `--max-num-seqs` | Maximo de secuencias concurrentes | 256 | +| `--default-temperature` | Temperatura por defecto cuando no se especifica en la solicitud | None | +| `--default-top-p` | top_p por defecto cuando no se especifica en la solicitud | None | +| `--max-audio-upload-mb` | Tamano máximo de audio subido para `/v1/audio/transcriptions` | 25 | +| `--max-tts-input-chars` | Longitud máxima de texto aceptada por `/v1/audio/speech` | 4096 | +| `--reasoning-parser` | Parser para modelos de reasoning (`qwen3`, `deepseek_r1`) | None | +| `--embedding-model` | Pre-carga un modelo de embeddings al iniciar | None | +| `--enable-auto-tool-choice` | Activa tool calling automático | False | +| `--tool-call-parser` | Parser de tool calling (`auto`, `mistral`, `qwen`, `llama`, `hermes`, `deepseek`, `kimi`, `granite`, `nemotron`, `xlam`, `functionary`, `glm47`) | None | + +### Ejemplos + +```bash +# Modo simple (usuario único, máximo rendimiento) +# La ruta del modelo se usa como nombre en la API de OpenAI (ej. model="mlx-community/Llama-3.2-3B-Instruct-4bit") +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit + +Model will show up as 'mlx-community/Llama-3.2-3B-Instruct-4bit' in the `/v1/models` API endpoint. View with `curl http://localhost:8000/v1/models` or similar. 
+ +# Con un nombre de modelo personalizado en la API (el modelo se accede como "my-model" via la API de OpenAI) +# --served-model-name establece el nombre que los clientes deben usar al llamar a la API (ej. model="my-model") +vllm-mlx serve --served-model-name my-model mlx-community/Llama-3.2-3B-Instruct-4bit +# Note: Model will show up as 'my-model' in the `/v1/models` API endpoint. + +# Continuous batching (multiples usuarios) +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --continuous-batching + +# Con limite de memoria para modelos grandes +vllm-mlx serve mlx-community/GLM-4.7-Flash-4bit \ + --continuous-batching \ + --cache-memory-mb 2048 + +# Produccion con paged cache +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --port 8000 + +# Con herramientas MCP +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json + +# Modelo multimodal +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Modelo de reasoning (separa el pensamiento de la respuesta) +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# Modelo de reasoning DeepSeek +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 + +# Tool calling con Mistral/Devstral +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral + +# Tool calling con Granite +vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \ + --enable-auto-tool-choice --tool-call-parser granite + +# Con autenticacion por clave de API +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --api-key your-secret-key + +# Exponer métricas de Prometheus +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --enable-metrics + +# Configuracion de produccion con opciones de seguridad +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --api-key your-secret-key \ + --rate-limit 60 \ + --timeout 120 \ + --continuous-batching +``` + +### Seguridad + +Cuando se establece `--api-key`, todas las solicitudes a la API requieren el encabezado `Authorization: Bearer `: + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="your-secret-key" # Must match --api-key +) +``` + +O con curl: + +```bash +curl http://localhost:8000/v1/models \ + -H "Authorization: Bearer your-secret-key" +``` + +## `vllm-mlx-bench` + +Ejecuta benchmarks de rendimiento. 
+ +### Uso + +```bash +vllm-mlx-bench --model [options] +``` + +### Opciones + +| Opcion | Descripcion | Por defecto | +|--------|-------------|-------------| +| `--model` | Nombre del modelo | Requerido | +| `--prompts` | Numero de prompts | 5 | +| `--max-tokens` | Maximo de tokens por prompt | 256 | +| `--quick` | Modo de benchmark rápido | False | +| `--video` | Ejecutar benchmark de video | False | +| `--video-url` | URL de video personalizada | None | +| `--video-path` | Ruta de video personalizada | None | + +### Ejemplos + +```bash +# Benchmark de LLM +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit + +# Benchmark rápido +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --quick + +# Benchmark de imagenes (deteccion automática para modelos VLM) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Benchmark de video +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video + +# Video personalizado +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit \ + --video --video-url https://example.com/video.mp4 +``` + +## `vllm-mlx-chat` + +Inicia la interfaz de chat con Gradio. + +### Uso + +```bash +vllm-mlx-chat --served-model-name [options] +``` + +### Opciones + +| Opcion | Descripcion | Por defecto | +|--------|-------------|-------------| +| `--model` | Nombre del modelo | Requerido | +| `--port` | Puerto de Gradio | 7860 | +| `--text-only` | Desactiva el modo multimodal | False | + +### Ejemplos + +```bash +# Chat multimodal (texto + imagenes + video) +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Chat solo de texto +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit --text-only +``` + +## Variables de entorno + +| Variable | Descripcion | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | Modelo para pruebas | +| `HF_TOKEN` | Token de HuggingFace | diff --git a/docs/es/reference/configuration.md b/docs/es/reference/configuration.md new file mode 100644 index 000000000..a7406a674 --- /dev/null +++ b/docs/es/reference/configuration.md @@ -0,0 +1,189 @@ +# Referencia de configuración + +## Configuracion del servidor + +### Opciones basicas + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--host` | Direccion del host del servidor | `127.0.0.1` | +| `--port` | Puerto del servidor | `8000` | +| `--max-tokens` | Maximo de tokens por defecto | `32768` | +| `--max-request-tokens` | Maximo de `max_tokens` aceptado de clientes de la API | `32768` | +| `--default-temperature` | Temperatura por defecto cuando no se especifica en la solicitud | None | +| `--default-top-p` | top_p por defecto cuando no se especifica en la solicitud | None | + +### Opciones de seguridad + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--api-key` | Clave de API para autenticacion | None | +| `--rate-limit` | Solicitudes por minuto por cliente (0 = deshabilitado) | `0` | +| `--timeout` | Tiempo de espera de la solicitud en segundos | `300` | +| `--enable-metrics` | Expone métricas de Prometheus en `/metrics` | `false` | +| `--max-audio-upload-mb` | Tamano máximo de audio subido para `/v1/audio/transcriptions` | `25` | +| `--max-tts-input-chars` | Longitud máxima de texto aceptada por `/v1/audio/speech` | `4096` | + +### Opciones de batching + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--continuous-batching` | Habilita el continuous batching | `false` | +| 
`--stream-interval` | Tokens por fragmento de streaming | `1` | +| `--max-num-seqs` | Maximo de secuencias concurrentes | `256` | + +### Opciones de cache + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--cache-memory-mb` | Limite de memoria de cache en MB | Auto | +| `--cache-memory-percent` | Fraccion de RAM para cache | `0.20` | +| `--no-memory-aware-cache` | Usa cache de conteo de entradas heredado | `false` | +| `--use-paged-cache` | Habilita el KV cache paginado | `false` | +| `--paged-cache-block-size` | Tokens por bloque | `64` | +| `--max-cache-blocks` | Maximo de bloques | `1000` | + +### Opciones de llamado a herramientas + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--enable-auto-tool-choice` | Habilita el llamado automático a herramientas | `false` | +| `--tool-call-parser` | Parser de llamados a herramientas (ver [Tool Calling](../guides/tool-calling.md)) | None | + +### Opciones de razonamiento + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--reasoning-parser` | Parser para modelos de razonamiento (`qwen3`, `deepseek_r1`) | None | + +### Opciones de embeddings + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--embedding-model` | Precarga un modelo de embeddings al iniciar | None | + +### Opciones de MCP + +| Opcion | Descripcion | Valor por defecto | +|--------|-------------|---------| +| `--mcp-config` | Ruta al archivo de configuración MCP | None | + +## Configuracion de MCP + +Crear `mcp.json`: + +```json +{ + "mcpServers": { + "server-name": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-name", "arg1"], + "env": { + "ENV_VAR": "value" + } + } + } +} +``` + +### Opciones del servidor MCP + +| Campo | Descripcion | Requerido | +|-------|-------------|----------| +| `command` | Comando ejecutable | Si | +| `args` | Argumentos del comando | Si | +| `env` | Variables de entorno | No | + +## Opciones de solicitudes a la API + +### Chat Completions + +| Parametro | Descripcion | Valor por defecto | +|-----------|-------------|---------| +| `model` | Nombre del modelo | Requerido | +| `messages` | Mensajes del chat | Requerido | +| `max_tokens` | Maximo de tokens a generar | 256 | +| `temperature` | Temperatura de muestreo | Valor por defecto del modelo | +| `top_p` | Nucleus sampling | Valor por defecto del modelo | +| `stream` | Habilita el streaming | `true` | +| `stop` | Secuencias de detencion | None | +| `tools` | Definiciones de herramientas | None | +| `response_format` | Formato de salida (`json_object`, `json_schema`) | None | + +### Opciones multimodales + +| Parametro | Descripcion | Valor por defecto | +|-----------|-------------|---------| +| `video_fps` | Fotogramas por segundo | 2.0 | +| `video_max_frames` | Maximo de fotogramas | 32 | + +## Variables de entorno + +| Variable | Descripcion | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | Modelo por defecto para pruebas | +| `HF_TOKEN` | Token de autenticacion de HuggingFace | +| `OPENAI_API_KEY` | Establecer a cualquier valor para compatibilidad con el SDK | + +## Configuraciones de ejemplo + +### Desarrollo (usuario único) + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### Produccion (multiples usuarios) + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --port 8000 +``` + +### Con llamado 
a herramientas + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral \ + --continuous-batching +``` + +### Con herramientas MCP + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --mcp-config mcp.json \ + --enable-auto-tool-choice \ + --tool-call-parser qwen \ + --continuous-batching +``` + +### Modelo de razonamiento + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit \ + --reasoning-parser qwen3 \ + --continuous-batching +``` + +### Con embeddings + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --embedding-model mlx-community/multilingual-e5-small-mlx \ + --continuous-batching +``` + +### Alto rendimiento + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --stream-interval 5 \ + --max-num-seqs 256 +``` diff --git a/docs/es/reference/models.md b/docs/es/reference/models.md new file mode 100644 index 000000000..13b645b56 --- /dev/null +++ b/docs/es/reference/models.md @@ -0,0 +1,99 @@ +# Modelos compatibles + +Todos los modelos cuantizados de [mlx-community en HuggingFace](https://huggingface.co/mlx-community/models) son compatibles. + +Explora miles de modelos preoptimizados en: **https://huggingface.co/mlx-community/models** + +## Modelos de lenguaje (vía mlx-lm) + +| Familia de modelos | Tamaños | Cuantización | +|--------------------|---------|--------------| +| Llama 3.x, 4.x | 1B, 3B, 8B, 70B | 4-bit | +| Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit | +| Qwen2/Qwen3 | 0.5B a 72B | Varios | +| DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit | +| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit | +| GLM-4.7 | Flash, Base | 4-bit, 8-bit | +| Kimi K2 | Varios | 4-bit | +| Phi-3 | 3.8B, 14B | 4-bit | +| Granite 3.x, 4.x | Varios | 4-bit | +| Nemotron | 3 Nano 30B | 6-bit | + +### Modelos recomendados + +| Caso de uso | Modelo | Memoria | +|-------------|--------|---------| +| Rápido / Liviano | `mlx-community/Qwen3-0.6B-8bit` | ~0.7 GB | +| Equilibrado | `mlx-community/Llama-3.2-3B-Instruct-4bit` | ~1.8 GB | +| Calidad | `mlx-community/Llama-3.1-8B-Instruct-4bit` | ~4.5 GB | +| Grande | `mlx-community/Qwen3-30B-A3B-4bit` | ~16 GB | + +## Modelos multimodales (vía mlx-vlm) + +| Familia de modelos | Modelos de ejemplo | +|--------------------|--------------------| +| **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` | +| **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` | +| **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` | +| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (visión + audio) | +| **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` | +| **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` | +| **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` | +| **Phi-3 Vision** | `Phi-3-vision-128k-instruct-4bit` | +| **DeepSeek-VL** | `deepseek-vl-7b-chat-4bit`, `deepseek-vl2-small-4bit` | + +### Modelos VLM recomendados + +| Caso de uso | Modelo | Memoria | +|-------------|--------|---------| +| Rápido / Liviano | `mlx-community/Qwen3-VL-4B-Instruct-3bit` | ~3 GB | +| Equilibrado | `mlx-community/Qwen3-VL-8B-Instruct-4bit` | ~6 GB | +| Calidad | `mlx-community/Qwen3-VL-30B-A3B-Instruct-6bit` | ~20 GB | + +## Modelos de embeddings (vía mlx-embeddings) + +| Familia de modelos | Modelos de ejemplo | +|--------------------|--------------------| +| **BERT** | `mlx-community/bert-base-uncased-mlx` | +| **XLM-RoBERTa** | 
`mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` | +| **ModernBERT** | `mlx-community/ModernBERT-base-mlx` | + +## Modelos de audio (vía mlx-audio) + +| Tipo | Familia de modelos | Modelos de ejemplo | +|------|--------------------|--------------------| +| **STT** | Whisper | `mlx-community/whisper-large-v3-turbo` | +| **STT** | Parakeet | `mlx-community/parakeet-tdt-0.6b-v2` | +| **TTS** | Kokoro | `prince-canuma/Kokoro-82M` | +| **TTS** | Chatterbox | `chatterbox/chatterbox-tts-0.1` | + +## Detección de modelos + +vllm-mlx detecta automáticamente los modelos multimodales por patrones en el nombre: +- Contiene "VL", "Vision", "vision" +- Contiene "llava", "idefics", "paligemma" +- Contiene "pixtral", "molmo", "deepseek-vl" +- Contiene "MedGemma", "Gemma-3", "Gemma-4" (variantes multimodales) + +## Usar modelos + +### Desde HuggingFace + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### Ruta local + +```bash +vllm-mlx serve /path/to/local/model +``` + +## Buscar modelos + +Filtra los modelos de mlx-community por: +- **LLM**: `Llama`, `Qwen`, `Mistral`, `Phi`, `Gemma`, `DeepSeek`, `GLM`, `Kimi`, `Granite`, `Nemotron` +- **VLM**: `-VL-`, `llava`, `paligemma`, `pixtral`, `molmo`, `idefics`, `deepseek-vl`, `MedGemma` +- **Embedding**: `e5`, `bert`, `ModernBERT` +- **Tamaño**: `1B`, `3B`, `7B`, `8B`, `70B` +- **Cuantización**: `4bit`, `8bit`, `bf16` diff --git a/docs/fr/benchmarks/README.md b/docs/fr/benchmarks/README.md new file mode 100644 index 000000000..0a226beba --- /dev/null +++ b/docs/fr/benchmarks/README.md @@ -0,0 +1,63 @@ +# Benchmarks + +Benchmarks de performance pour vllm-mlx sur Apple Silicon. + +## Types de benchmarks + +- [Benchmarks LLM](llm.md) - Performance de génération de texte +- [Benchmarks image](image.md) - Performance de compréhension d'images +- [Benchmarks vidéo](video.md) - Performance de compréhension de vidéos + +## Commandes rapides + +```bash +# LLM benchmark +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# Image benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Video benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video +``` + +## Valeurs par défaut des scripts de test autonomes + +Les scripts de benchmark autonomes disposent de modèles par défaut intégrés, ce qui permet de les lancer directement : + +```bash +python tests/test_continuous_batching.py +python tests/test_prefix_cache.py +``` + +Valeurs par défaut : +- `tests/test_continuous_batching.py` → `mlx-community/Qwen3-8B-6bit` +- `tests/test_prefix_cache.py` → `mlx-community/Qwen3-0.6B-8bit` + +Pour tester d'autres modèles, utilisez l'option `--model` : + +```bash +python tests/test_continuous_batching.py --model mlx-community/Qwen3-0.6B-8bit +python tests/test_prefix_cache.py --model mlx-community/Qwen3-8B-6bit +``` + +## Matériel + +Les benchmarks ont été collectés sur les configurations Apple Silicon suivantes : + +| Puce | Mémoire | Python | +|------|---------|--------| +| Apple M4 Max | 128 Go unifiée | 3.13 | +| Apple M1 Max | 64 Go unifiée | 3.12 | + +Les résultats varieront selon la puce Apple Silicon utilisée. + +## Contribuer des benchmarks + +Si vous disposez d'une puce Apple Silicon différente, partagez vos résultats : + +```bash +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json +``` + +Ouvrez un ticket avec vos résultats sur [GitHub Issues](https://github.com/waybarrios/vllm-mlx/issues). 
diff --git a/docs/fr/benchmarks/audio.md b/docs/fr/benchmarks/audio.md new file mode 100644 index 000000000..a99b8a5c4 --- /dev/null +++ b/docs/fr/benchmarks/audio.md @@ -0,0 +1,158 @@ +# Benchmarks Audio + +## Benchmarks STT (Speech-to-Text) + +### Lancer les benchmarks STT + +```bash +# Run with default test audio +python examples/benchmark_audio.py --stt + +# Run with your own audio file +python examples/benchmark_audio.py --stt --audio path/to/audio.wav +``` + +### Résultats (M4 Max, 128 Go) + +**Audio de test :** 46,7 secondes de synthèse vocale + +| Model | Parameters | Load Time | Transcribe Time | RTF* | +|-------|------------|-----------|-----------------|------| +| whisper-tiny | 39M | 0.34s | 0.24s | **197x** | +| whisper-small | 244M | 0.18s | 0.47s | **98x** | +| whisper-medium | 769M | 0.35s | 1.15s | **41x** | +| whisper-large-v3 | 1.5B | 0.50s | 1.96s | **24x** | +| whisper-large-v3-turbo | 809M | 0.12s | 0.86s | **55x** | + +*RTF = Real-Time Factor (plus la valeur est élevée, plus c'est rapide). Un RTF de 100x signifie qu'une minute d'audio est transcrite en environ 0,6 secondes.* + +### Résultats (M1 Max, 64 Go) + +STT avec Parakeet (environnement par défaut, Whisper indisponible en raison d'une incompatibilité de dépendance numpy) : + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| parakeet-tdt-0.6b-v2 | 0.28s | 1.01s | **9.9x** | +| parakeet-tdt-0.6b-v3 | 0.30s | 0.19s | **52.7x** | + +STT avec Whisper (`numpy==2.3.5` explicite + `uv run --no-sync`) : + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| whisper-tiny | 4.02s | 1.05s | **9.5x** | +| whisper-small | 10.15s | 1.03s | **9.7x** | +| whisper-medium | 22.96s | 2.20s | **4.6x** | +| whisper-large-v3 | 38.34s | 0.96s | **10.5x** | +| whisper-large-v3-turbo | 21.79s | 0.70s | **14.3x** | +| parakeet-tdt-0.6b-v2 | 0.47s | 0.18s | **54.4x** | +| parakeet-tdt-0.6b-v3 | 1.13s | 0.18s | **54.6x** | + +### Recommandations de modèles + +| Use Case | Recommended Model | Why | +|----------|-------------------|-----| +| **Transcription en temps réel** | whisper-tiny | Le plus rapide (RTF 197x), faible latence | +| **Usage général** | whisper-large-v3-turbo | Meilleur compromis vitesse (55x) et qualité | +| **Précision maximale** | whisper-large-v3 | Le plus précis, prend en charge plus de 99 langues | +| **Mémoire limitée** | whisper-small | Bonne qualité à 244M paramètres | + +### Qualité de transcription + +Tous les modèles ont correctement transcrit l'audio de test. Exemple de sortie : + +``` +Input text: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." + +Whisper-large-v3 output: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." 
(identical) +``` + +### Langues prises en charge + +Les modèles Whisper prennent en charge plus de 99 langues, notamment : +- Anglais, espagnol, français, allemand, italien, portugais +- Chinois (mandarin, cantonais), japonais, coréen +- Arabe, hindi, russe, turc, ukrainien +- Et bien d'autres + +## Benchmarks TTS (Text-to-Speech) + +### Lancer les benchmarks TTS + +```bash +python examples/benchmark_audio.py --tts +``` + +### Résultats (M4 Max, 128 Go) + +**Test :** Génération audio pour 3 échantillons de texte (court, moyen, long) + +| Model | Load Time | Chars/sec | RTF* | +|-------|-----------|-----------|------| +| Kokoro-82M-bf16 | 0.8s | 350+ | **22x** | +| Kokoro-82M-4bit | 0.4s | 320+ | **20x** | + +*RTF = Real-Time Factor. Un RTF de 22x signifie qu'une seconde d'audio est générée en environ 0,045 secondes.* + +### Résultats TTS (M1 Max, 64 Go) + +| Model | Load Time | Avg Chars/s | Avg RTF | +|-------|-----------|-------------|---------| +| Kokoro-82M-bf16 | 2.81s | 176.0 | **11.9x** | +| Kokoro-82M-4bit | 0.22s | 225.6 | **15.5x** | + +### Qualité TTS + +Kokoro produit une synthèse vocale au son naturel avec : +- 11 voix intégrées (masculines et féminines) +- Prise en charge de 8 langues (anglais, espagnol, français, japonais, chinois, italien, portugais, hindi) +- 82M paramètres, rapide et léger + +## Benchmarks de traitement audio + +### SAM-Audio (séparation de sources) + +**Test :** Séparation de la batterie dans un morceau de rock de 30 secondes + +| Metric | Value | +|--------|-------| +| Model | sam-audio-large-fp16 | +| Processing time | ~20s | +| Peak memory | ~27 GB | +| Output sample rate | 48000 Hz | + +## Lancer tous les benchmarks audio + +```bash +# Run all benchmarks +python examples/benchmark_audio.py --all + +# Or run individually +python examples/benchmark_audio.py --stt +python examples/benchmark_audio.py --tts +``` + +## Modèles disponibles sur mlx-community + +### Modèles STT +- `mlx-community/whisper-tiny-mlx` +- `mlx-community/whisper-small-mlx` +- `mlx-community/whisper-medium-mlx` +- `mlx-community/whisper-large-v3-mlx` +- `mlx-community/whisper-large-v3-turbo` +- `mlx-community/parakeet-tdt-0.6b-v2` +- `mlx-community/parakeet-tdt-0.6b-v3` + +### Modèles TTS +- `mlx-community/Kokoro-82M-bf16` (recommandé) +- `mlx-community/Kokoro-82M-4bit` +- `mlx-community/chatterbox-turbo-fp16` +- `mlx-community/VibeVoice-Realtime-0.5B-4bit` + +### Traitement audio +- `mlx-community/sam-audio-large-fp16` diff --git a/docs/fr/benchmarks/image.md b/docs/fr/benchmarks/image.md new file mode 100644 index 000000000..f7eb46474 --- /dev/null +++ b/docs/fr/benchmarks/image.md @@ -0,0 +1,138 @@ +# Benchmarks d'images + +## Lancer les benchmarks d'images + +```bash +# Benchmark complet (10 résolutions) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Benchmark rapide (4 résolutions) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --quick +``` + +## Résultats - Qwen3-VL-8B-Instruct-4bit (M4 Max, 128GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.04s | 78 | 74.8 tok/s | +| 336x336 | 113K | 0.94s | 64 | 68.3 tok/s | +| 448x448 | 201K | 1.45s | 70 | 48.1 tok/s | +| 512x512 | 262K | 1.58s | 99 | 62.8 tok/s | +| 672x672 | 452K | 1.83s | 83 | 45.3 tok/s | +| 768x768 | 590K | 2.05s | 91 | 44.3 tok/s | +| 896x896 | 803K | 2.61s | 90 | 34.5 tok/s | +| 1024x1024 | 1.0M | 2.79s | 76 | 27.2 tok/s | +| 1280x720 | 922K | 2.97s | 96 | 32.4 tok/s | +| 1920x1080 | 2.1M | 6.30s | 89 | 
14.1 tok/s | + +**Résumé :** Moyenne de 45.2 tok/s sur toutes les résolutions. Le plus rapide à 224x224 (74.8 tok/s), le plus lent à 1920x1080 (14.1 tok/s) + +## Résultats - Qwen3-VL-8B-Instruct-4bit (M1 Max, 64GB) + +Benchmark MLLM local : + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.84s | 78 | 42.5 tok/s | +| 448x448 | 201K | 2.28s | 70 | 30.7 tok/s | +| 768x768 | 590K | 4.39s | 91 | 20.7 tok/s | +| 1024x1024 | 1.0M | 6.41s | 76 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 14.92 | 315 | 21.1 | + +## Résultats - Qwen3-VL-4B-Instruct-3bit Serveur (M1 Max, 64GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.65s | 113 | 68.4 tok/s | +| 448x448 | 201K | 2.09s | 120 | 57.5 tok/s | +| 768x768 | 590K | 2.93s | 106 | 36.2 tok/s | +| 1024x1024 | 1.0M | 4.12s | 100 | 24.3 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 10.79 | 439 | 40.7 | + +## Résultats du cache de préfixe MLLM + +``` +====================================================================== + MLLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-VL-4B-Instruct-3bit + Test: Verify KV cache reuse for repeated image/video + prompt combinations + Expected behavior: + - Same image + same prompt → cache HIT + - Same image + different prompt → cache MISS + - Different image + same prompt → cache MISS +---------------------------------------------------------------------- + SETUP: Loading Model +---------------------------------------------------------------------- + Model loaded in 0.11s + +---------------------------------------------------------------------- + SETUP: Creating Test Images +---------------------------------------------------------------------- + Resized: 224x224, 336x336, 512x512, 768x768 + +---------------------------------------------------------------------- + TEST 1: Image Cache - Basic Hit/Miss +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 1a | First image+prompt | MISS | MISS | 0.10ms | ✓ + 1b | Same image+prompt | HIT | HIT | 0.18ms | ✓ + 1c | Different prompt | MISS | MISS | 0.01ms | ✓ + 1d | Return to original | HIT | HIT | 0.18ms | ✓ + +---------------------------------------------------------------------- + TEST 2: Different Images +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 2a | Image A first request | MISS | MISS | 0.01ms | ✓ + 2b | Image B first request | MISS | MISS | 0.01ms | ✓ + 2c | Image A cached | HIT | HIT | 0.13ms | ✓ + +---------------------------------------------------------------------- + TEST 3: Image Resolutions +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+-----------------------+----------+--------+--------+------- + 3.1a | 224x224 first | MISS | MISS | 0.01ms | ✓ + 3.1b | 224x224 cached | HIT | HIT | 0.20ms | ✓ + 3.2a | 336x336 first | 
MISS | MISS | 0.01ms | ✓ + 3.2b | 336x336 cached | HIT | HIT | 0.21ms | ✓ + 3.3a | 512x512 first | MISS | MISS | 0.12ms | ✓ + 3.3b | 512x512 cached | HIT | HIT | 0.20ms | ✓ + 3.4a | 768x768 first | MISS | MISS | 0.12ms | ✓ + 3.4b | 768x768 cached | HIT | HIT | 0.24ms | ✓ +====================================================================== +``` + +## Stratégie de clé de cache + +- **Images :** `hash(image_content) + hash(prompt)` + +Une même image avec le même prompt touchera toujours le cache. Une image différente ou un prompt différent provoquera un cache miss. + +## Conseils de performance + +- Les résolutions plus petites sont traitées plus rapidement (224x224 contre 1920x1080) +- Utilisez la résolution adaptée à votre tâche +- Regroupez les images de taille similaire pour un débit constant + +## Référence des métriques + +| Metric | Description | +|--------|-------------| +| Resolution | Dimensions de l'image (largeur x hauteur) | +| Pixels | Nombre total de pixels | +| Time | Durée de génération | +| Tokens | Tokens de sortie générés | +| Speed | Tokens par seconde (tok/s) | diff --git a/docs/fr/benchmarks/llm.md b/docs/fr/benchmarks/llm.md new file mode 100644 index 000000000..72e8483e7 --- /dev/null +++ b/docs/fr/benchmarks/llm.md @@ -0,0 +1,271 @@ +# Benchmarks LLM + +## Lancer les benchmarks LLM + +```bash +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 5 --max-tokens 256 +``` + +## Résultats (M4 Max, 128 Go) + +| Modèle | Vitesse de génération | TTFT* | Mémoire | +|--------|-----------------------|-------|---------| +| Qwen3-0.6B-8bit | 402,3 tok/s | 58,6 ms | 0,68 Go | +| Llama-3.2-1B-Instruct-4bit | 463,6 tok/s | 49,2 ms | 0,69 Go | +| Qwen2.5-1.5B-Instruct-4bit | 308,5 tok/s | 86,2 ms | 0,84 Go | +| Llama-3.2-3B-Instruct-4bit | 200,1 tok/s | 81,4 ms | 1,79 Go | +| Qwen3-30B-A3B-4bit | 123,9 tok/s | 126,9 ms | 16,05 Go | +| NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit | 122,9 tok/s | 72,3 ms | 23,98 Go | + +*TTFT = Time to First Token (latence jusqu'au premier token généré) + +## Résultats (M1 Max, 64 Go) + +| Modèle | Requêtes | Tok. prompt | Tok. 
générés | Temps total (s) | TTFT moyen (ms) | TPOT moyen (ms) | Vitesse génération (tok/s) | Débit total (tok/s) | +|--------|----------|-------------|--------------|-----------------|-----------------|-----------------|---------------------------|---------------------| +| Qwen3-0.6B-8bit | 5 | 56 | 1280 | 5,66 | 119,0 | 3,97 | 251,9 | 236,1 | + +## Résultats du continuous batching + +| Modèle | Requête unique | Batch (5 req) | Accélération | +|--------|----------------|---------------|--------------| +| Llama-3.2-1B-Instruct-4bit | 299,1 tok/s | 613,0 tok/s | **2,05x** | +| Llama-3.2-3B-Instruct-4bit | 137,6 tok/s | 208,1 tok/s | **1,51x** | +| Qwen3-0.6B-8bit | 328,1 tok/s | 1111,8 tok/s | **3,39x** | +| Qwen3-30B-A3B-4bit | 98,1 tok/s | 233,3 tok/s | **2,38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196,9 tok/s | 322,2 tok/s | **1,64x** | + +*Le batching de 5 requêtes simultanées apporte une amélioration du throughput de 1,5 à 3x.* + +### Continuous batching (M1 Max, 64 Go) + +| Requêtes | Tokens totaux | Temps total (s) | Throughput (tok/s) | Requêtes/sec | +|----------|---------------|-----------------|--------------------|--------------| +| 5 | 315 | 0,64 | 492,5 | 7,82 | + +## Performances en streaming + +| Modèle | TTFT | Vitesse de génération | +|--------|------|-----------------------| +| Llama-3.2-1B-Instruct-4bit | ~4,6 ms | 218,9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10,7 ms | 93,6 tok/s | +| Qwen3-0.6B-8bit | ~3,0 ms | 328,5 tok/s | +| Qwen3-30B-A3B-4bit | ~10,2 ms | 98,4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7,1 ms | 140,3 tok/s | + +### Détokeniseur en streaming (M1 Max, 64 Go) + +`vllm-mlx bench-detok` : + +| Tokens | Itérations | Temps naïf | Temps streaming | Accélération | +|--------|------------|------------|-----------------|--------------| +| 742 | 5 | 1,69 ms | 0,71 ms | 2,39x | + +`examples/benchmark_detokenizer.py` : + +| Séquence | Tokens | decode() | Streaming | Accélération | +|----------|--------|----------|-----------|--------------| +| Courte | 8 | 0,029 ms | 0,028 ms | 1,04x | +| Moyenne | 103 | 0,206 ms | 0,129 ms | 1,59x | +| Longue | 511 | 1,040 ms | 0,502 ms | 2,07x | +| 1K | 1191 | 2,446 ms | 1,178 ms | 2,08x | +| 2K | 2381 | 4,949 ms | 2,356 ms | 2,10x | +| 4K | 4761 | 9,887 ms | 5,398 ms | 1,83x | + +Accélération moyenne : 1,79x + +## Résultats du prefix cache + +### Prefix cache (M4 Max, 128 Go) + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | ✓ + 1b | Same prompt | HIT | HIT | ✓ + 1c | Different prompt | MISS | MISS | ✓ + 1d | Return to prompt 1 | HIT | HIT | ✓ +====================================================================== +``` + +### Prefix cache (M1 Max, 64 Go) + +| Test | Attendu | Réel | Temps | Statut | +|------|---------|------|-------|--------| +| Première requête | MISS | MISS | 203,5 ms | PASS | +| Même prompt | HIT | HIT | 131,6 ms | PASS | +| Prompt différent | MISS ou PREFIX_HIT | PREFIX_HIT (5 tok) | 135,3 ms | PASS | + +Statistiques finales du cache : + +| Hits cache | Misses cache | Taux de hit | Tokens économisés | Accélération avec cache | 
+|------------|--------------|-------------|-------------------|------------------------| +| 2 | 1 | 66,7 % | 20 | 1,55x | + +## Résultats du paged cache + +*Test : 20 requêtes d'inférence réelles en 2 rounds avec un prompt système partagé d'environ 286 tokens* + +``` +====================================================================== + PAGED KV CACHE - REAL INFERENCE TEST +====================================================================== + +-------------------------------------------------- +Test 1: WITHOUT Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.47s + Throughput: 681.2 tok/s + Cache hits: 0 + Tokens saved: 0 + +-------------------------------------------------- +Test 2: WITH Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.31s + Throughput: 765.8 tok/s + + Paged Cache Stats: + Blocks allocated: 25 + Shared blocks: 4 + Cache hits: 10 + Tokens saved: 2560 + +================================================== +SUMMARY +================================================== + Without paged cache: 681.2 tok/s + With paged cache: 765.8 tok/s + + Speedup: 1.12x + Cache hits: 10 (all Round 2 requests) + Tokens saved: 2,560 (~256 tokens × 10 requests) +================================================== +``` + +### KV cache paginé (M1 Max, 64 Go) + +Benchmark d'inférence (20 requêtes) : + +| Mode | Temps (s) | Throughput (tok/s) | +|------|-----------|--------------------| +| Sans paged cache | 3,43 | 291,8 | +| Avec paged cache | 3,42 | 292,2 | + +| Accélération | Blocs alloués | Blocs partagés | Hits cache | Tokens économisés | +|--------------|---------------|----------------|------------|-------------------| +| 1,00x | 45 | 4 | 10 | 2560 | + +Inférence concurrente réelle (20 requêtes) : + +| Mode | Temps (s) | Throughput (tok/s) | +|------|-----------|--------------------| +| Sans paged cache | 4,32 | 231,7 | +| Avec paged cache | 4,35 | 229,7 | + +| Accélération | Blocs alloués | Blocs partagés | Hits cache | Tokens économisés | +|--------------|---------------|----------------|------------|-------------------| +| 0,99x | 49 | 8 | 10 | 5120 | + +Démonstration des économies mémoire : + +| Scénario | Économies mémoire | +|----------|-------------------| +| Prompts système partagés | 70,8 % | +| Efficacité mémoire concurrente | 83,5 % | +| Branches avec partage de préfixe | 38,5 % | + +## Analyse du détokeniseur en streaming + +*Investigation phase 9.1 : `BPEStreamingDetokenizer` de mlx-lm vs `tokenizer.decode()` naïf* + +### Contexte + +L'approche naïve appelle `decode([token])` pour chaque token. En théorie, les détokeniseurs en streaming offrent une complexité O(T) contre O(T²) pour le décodage naïf. + +### Résultats du benchmark isolé + +```bash +vllm-mlx bench-detok +``` + +En réutilisant la même instance de détokeniseur (avec `reset()` entre les utilisations) : + +| Séquence | Tokens | decode() naïf | Streaming | Accélération | +|----------|--------|---------------|-----------|--------------| +| Courte | 8 | 0,020 ms | 0,019 ms | 1,05x | +| Moyenne | 103 | 0,155 ms | 0,097 ms | 1,59x | +| Longue | 511 | 0,752 ms | 0,371 ms | **2,03x** | +| 1K tokens | 1191 | 1,743 ms | 0,833 ms | **2,09x** | +| 2K tokens | 2381 | 3,493 ms | 1,737 ms | **2,01x** | + +### Constat critique : surcoût de création d'instance + +La création d'une nouvelle instance de `BPEStreamingDetokenizer` est **extrêmement coûteuse** : + +``` +100 tokenizer.detokenizer calls: 5.266s (52.7ms each!) 
+``` + +Cela signifie que créer un nouveau détokeniseur par requête ajoute **environ 52 ms de surcoût**, annulant tout bénéfice. + +### Impact en conditions réelles + +Intégré dans le scheduler (un détokeniseur par requête) : + +| Métrique | decode() naïf | Streaming (nouvelle instance) | +|----------|---------------|-------------------------------| +| Throughput (20 req) | 681 tok/s | 275 tok/s | +| Impact | - | **-60 % plus lent** | + +### Conclusion + +Le détokeniseur en streaming n'est **pas viable actuellement** pour un usage par requête, en raison du coût de création d'instance. L'approche naïve `decode([token])` reste plus rapide en pratique. + +**Optimisation future** : pré-créer un pool d'instances de détokeniseur au démarrage et les réutiliser entre les requêtes. + +## Référence des métriques + +| Métrique | Description | +|----------|-------------| +| **TTFT** | Time to First Token - latence jusqu'à ce que le modèle commence à répondre (ms) | +| **TPOT** | Time Per Output Token - temps entre chaque token généré (ms/token) | +| **Generation TPS** | Tokens de sortie par seconde (tok/s) | +| **Processing TPS** | Tokens d'entrée/prompt traités par seconde (tok/s) | +| **End-to-End Latency** | Temps total de la requête à la réponse complète | +| **Total Throughput** | Tokens totaux (entrée + sortie) par seconde | + +## Lancer les benchmarks + +```bash +# Benchmark de base +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# Avec davantage de prompts +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --prompts 10 + +# Sauvegarder les résultats +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json + +# Test de continuous batching +python tests/test_continuous_batching.py + +# Test de prefix cache +python tests/test_prefix_cache.py + +# Test de paged cache +python tests/test_paged_cache_real_inference.py + +# Benchmark du détokeniseur en streaming +vllm-mlx bench-detok +vllm-mlx bench-detok mlx-community/Llama-3.2-1B-Instruct-4bit --iterations 5 +``` diff --git a/docs/fr/benchmarks/video.md b/docs/fr/benchmarks/video.md new file mode 100644 index 000000000..4bba02ce9 --- /dev/null +++ b/docs/fr/benchmarks/video.md @@ -0,0 +1,128 @@ +# Benchmarks vidéo + +## Lancer les benchmarks vidéo + +```bash +# Benchmark complet (10 configurations, 2-64 frames) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video + +# Benchmark rapide (3 nombres de frames) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --quick + +# Vidéo personnalisée +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --video-url https://example.com/video.mp4 +``` + +## Résultats - Qwen3-VL-8B-Instruct-4bit (M4 Max, 128GB) + +| Configuration | Frames | Time | Tokens | Speed | Memory | +|---------------|--------|------|--------|-------|--------| +| 2 frames @ 0.5fps | 2 | 4.48s | 256 | 57.1 tok/s | 6.4 GB | +| 4 frames @ 1fps | 4 | 4.65s | 256 | 55.0 tok/s | 6.4 GB | +| 6 frames @ 1fps | 6 | 5.15s | 197 | 38.2 tok/s | 6.6 GB | +| 8 frames @ 2fps | 8 | 6.45s | 240 | 37.2 tok/s | 6.8 GB | +| 12 frames @ 2fps | 12 | 8.73s | 256 | 29.3 tok/s | 7.1 GB | +| 16 frames @ 2fps | 16 | 10.96s | 256 | 23.4 tok/s | 7.6 GB | +| 24 frames @ 4fps | 24 | 14.95s | 226 | 15.1 tok/s | 8.4 GB | +| 32 frames @ 4fps | 32 | 20.00s | 256 | 12.8 tok/s | 9.2 GB | +| 48 frames @ 8fps | 48 | 31.11s | 246 | 7.9 tok/s | 11.1 GB | +| 64 frames @ 8fps | 64 | 59.81s | 256 | 4.3 tok/s | 12.9 GB | + +**Résumé :** Le plus rapide à 2 frames (57.1 tok/s), le plus lent à 64 
frames (4.3 tok/s). La mémoire varie de 6.4 GB à 12.9 GB. + +> **Note :** 96 frames et plus provoque un timeout GPU sur la plupart des machines en raison des limites de mémoire et de calcul. + +## Résultats - Qwen3-VL-8B-Instruct-4bit (M1 Max, 64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 8.84s | 256 | 29.0 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 13.05s | 256 | 19.6 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 21.60s | 256 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 43.48 | 768 | 17.7 | + +## Résultats - Qwen3-VL-4B-Instruct-3bit (M1 Max, 64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 5.09s | 150 | 29.5 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 8.36s | 150 | 17.9 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 15.21s | 150 | 9.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 28.66 | 450 | 15.7 | + +## Résultats du cache vidéo + +``` +---------------------------------------------------------------------- + TEST 4: Video Cache - fps/max_frames in Cache Key +---------------------------------------------------------------------- + Config: fps=2.0, max_frames=16 + + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 4a | Video first request | MISS | MISS | 0.03ms | ✓ + 4b | Same video+params | HIT | HIT | 0.14ms | ✓ + 4c | Different fps (4.0) | MISS | MISS | 0.01ms | ✓ + 4d | Different max_frames (32) | MISS | MISS | 0.01ms | ✓ + 4.0.5a | fps=0.5 first | MISS | MISS | 0.01ms | ✓ + 4.0.5b | fps=0.5 cached | HIT | HIT | 0.14ms | ✓ + 4.1.0a | fps=1.0 first | MISS | MISS | 0.01ms | ✓ + 4.1.0b | fps=1.0 cached | HIT | HIT | 0.14ms | ✓ + 4.2.0a | fps=2.0 first | MISS | MISS | 0.01ms | ✓ + 4.2.0b | fps=2.0 cached | HIT | HIT | 0.14ms | ✓ + 4.4.0a | fps=4.0 first | MISS | MISS | 0.01ms | ✓ + 4.4.0b | fps=4.0 cached | HIT | HIT | 0.14ms | ✓ + +---------------------------------------------------------------------- + TEST 5: Additional Videos +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 5a | Video 1 first | MISS | MISS | 0.01ms | ✓ + 5b | Video 2 first | MISS | MISS | 0.01ms | ✓ + 5c | Video 1 cached | HIT | HIT | 0.13ms | ✓ + 5d | Video 2 cached | HIT | HIT | 0.13ms | ✓ +``` + +## Stratégie de clé de cache + +- **Vidéos :** `hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +La même vidéo avec les mêmes paramètres fps, max_frames et prompt donnera un HIT dans le cache. La modification de l'un quelconque de ces paramètres provoque un MISS. 
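+
+À titre d'illustration, l'esquisse Python ci-dessous (indépendante du code réel de vllm-mlx, nom de fonction hypothétique) montre comment une telle clé composite peut être construite : tout changement de fps, de max_frames ou de prompt produit une clé différente, donc un MISS.
+
+```python
+import hashlib
+
+def video_cache_key(video_path: str, fps: float, max_frames: int, prompt: str) -> str:
+    """Clé composite illustrative : chaque composant est haché avec un séparateur explicite."""
+    h = hashlib.sha256()
+    for part in (video_path, f"{fps:.3f}", str(max_frames), prompt):
+        h.update(part.encode("utf-8"))
+        h.update(b"\x00")  # séparateur pour éviter les concaténations ambiguës
+    return h.hexdigest()
+
+# Même vidéo + mêmes paramètres -> même clé (HIT) ; fps différent -> clé différente (MISS)
+assert video_cache_key("video.mp4", 2.0, 16, "Describe") != video_cache_key("video.mp4", 4.0, 16, "Describe")
+```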
+ +## Conseils de performance + +- Un FPS plus faible accélère le traitement +- Moins de frames réduit l'utilisation mémoire +- 64 frames est le maximum pratique +- 96 frames et plus provoque un timeout GPU + +## Extraction de frames + +| FPS | Vidéo 10s | Vidéo 30s | Vidéo 60s | +|-----|-----------|-----------|-----------| +| 0.5 | 5 frames | 15 frames | 30 frames | +| 1.0 | 10 frames | 30 frames | 60 frames | +| 2.0 | 20 frames | 60 frames | 120 frames* | +| 4.0 | 40 frames | 120 frames* | 240 frames* | + +*Peut atteindre la limite `max_frames` + +## Référence des métriques + +| Metric | Description | +|--------|-------------| +| Configuration | Paramètres FPS et nombre maximum de frames | +| Frames | Nombre réel de frames extraites | +| Time | Temps total de génération | +| Tokens | Tokens de sortie générés | +| Speed | Tokens par seconde (tok/s) | +| Memory | Utilisation de la mémoire GPU | diff --git a/docs/fr/getting-started/installation.md b/docs/fr/getting-started/installation.md new file mode 100644 index 000000000..444970b0e --- /dev/null +++ b/docs/fr/getting-started/installation.md @@ -0,0 +1,89 @@ +# Installation + +## Prérequis + +- macOS sur Apple Silicon (M1/M2/M3/M4) +- Python 3.10+ + +## Installation avec uv (recommandée) + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +uv pip install -e . +``` + +## Installation avec pip + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +pip install -e . +``` + +### Optionnel : support vidéo + +Pour le traitement vidéo avec transformers : + +```bash +pip install -e ".[vision]" +``` + +### Optionnel : support audio (STT/TTS) + +```bash +pip install mlx-audio +``` + +### Optionnel : embeddings + +```bash +pip install mlx-embeddings +``` + +## Ce qui est installé + +- `mlx`, `mlx-lm`, `mlx-vlm` - framework MLX et bibliothèques de modèles +- `transformers`, `tokenizers` - bibliothèques HuggingFace +- `opencv-python` - traitement vidéo +- `gradio` - interface de chat +- `psutil` - surveillance des ressources +- `mlx-audio` (optionnel) - Speech-to-Text et Text-to-Speech +- `mlx-embeddings` (optionnel) - embeddings de texte + +## Vérifier l'installation + +```bash +# Check CLI commands +vllm-mlx --help +vllm-mlx-bench --help +vllm-mlx-chat --help + +# Test with a small model +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 1 +``` + +## Dépannage + +### MLX introuvable + +Vérifiez que vous êtes sur Apple Silicon : +```bash +uname -m # Should output "arm64" +``` + +### Échec du téléchargement du modèle + +Vérifiez votre connexion internet et vos accès HuggingFace. 
Certains modèles nécessitent une authentification : +```bash +huggingface-cli login +``` + +### Mémoire insuffisante + +Utilisez un modèle quantifié plus petit : +```bash +vllm-mlx serve mlx-community/Llama-3.2-1B-Instruct-4bit +``` diff --git a/docs/fr/getting-started/quickstart.md b/docs/fr/getting-started/quickstart.md new file mode 100644 index 000000000..b45ce7014 --- /dev/null +++ b/docs/fr/getting-started/quickstart.md @@ -0,0 +1,133 @@ +# Démarrage rapide + +## Option 1 : serveur compatible OpenAI + +Démarrez le serveur : + +```bash +# Simple mode - maximum throughput for single user +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# Continuous batching - for multiple concurrent users +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +Utilisation avec le SDK Python OpenAI : + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="mlx-community/Llama-3.2-3B-Instruct-4bit", + messages=[{"role": "user", "content": "Hello!"}], +) +print(response.choices[0].message.content) +``` + +Ou avec curl : + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "default", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + +## Option 2 : API Python directe + +```python +from vllm_mlx.models import MLXLanguageModel + +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) + +# Streaming +for chunk in model.stream_generate("Tell me a story"): + print(chunk.text, end="", flush=True) +``` + +## Option 3 : interface de chat Gradio + +```bash +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +Ouvre une interface web à l'adresse http://localhost:7860 + +## Modèles multimodaux + +Pour la compréhension d'images et de vidéos, utilisez un modèle VLM : + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +``` + +## Modèles de raisonnement + +Séparez le processus de réflexion du modèle de la réponse finale : + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) +print(response.choices[0].message.content) # Final answer +``` + +## Embeddings + +Générez des embeddings textuels pour la recherche sémantique et le RAG : + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit --embedding-model mlx-community/multilingual-e5-small-mlx +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +``` + +## Tool Calling + +Activez l'appel de fonctions avec tout modèle compatible : + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +## Étapes suivantes + +- [Server Guide](../guides/server.md) - Configuration complète du serveur +- [Python API](../guides/python-api.md) - 
Utilisation directe de l'API +- [Multimodal Guide](../guides/multimodal.md) - Images et vidéos +- [Audio Guide](../guides/audio.md) - Speech-to-Text et Text-to-Speech +- [Embeddings Guide](../guides/embeddings.md) - Embeddings textuels +- [Reasoning Models](../guides/reasoning.md) - Modèles de réflexion +- [Tool Calling](../guides/tool-calling.md) - Appel de fonctions +- [Supported Models](../reference/models.md) - Modèles disponibles diff --git a/docs/fr/guides/audio.md b/docs/fr/guides/audio.md new file mode 100644 index 000000000..a8bc08371 --- /dev/null +++ b/docs/fr/guides/audio.md @@ -0,0 +1,524 @@ +# Support Audio + +vllm-mlx prend en charge le traitement audio via [mlx-audio](https://github.com/Blaizzy/mlx-audio), offrant : + +- **STT (Speech-to-Text)** : Whisper, Parakeet +- **TTS (Text-to-Speech)** : Kokoro, Chatterbox, VibeVoice, VoxCPM +- **Traitement audio** : SAM-Audio (séparation vocale) + +## Installation + +```bash +# Support audio de base +pip install mlx-audio>=0.2.9 + +# Required dependencies for TTS +pip install sounddevice soundfile scipy numba tiktoken misaki spacy num2words loguru phonemizer + +# Download spacy English model +python -m spacy download en_core_web_sm + +# For non-English TTS (Spanish, French, etc.), install espeak-ng: +# macOS +brew install espeak-ng + +# Ubuntu/Debian +# sudo apt-get install espeak-ng +``` + +Ou installez toutes les dépendances audio en une seule commande : + +```bash +pip install vllm-mlx[audio] +python -m spacy download en_core_web_sm +brew install espeak-ng # macOS, for non-English languages +``` + +## Démarrage rapide + +### Speech-to-Text (Transcription) + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Transcribe audio file +with open("audio.mp3", "rb") as f: + transcript = client.audio.transcriptions.create( + model="whisper-large-v3", + file=f, + language="en" # optional + ) +print(transcript.text) +``` + +### Text-to-Speech (Génération) + +```python +# Generate speech +audio = client.audio.speech.create( + model="kokoro", + input="Hello, how are you?", + voice="af_heart", + speed=1.0 +) + +# Save to file +with open("output.wav", "wb") as f: + f.write(audio.content) +``` + +### Séparation vocale (SAM-Audio) + +Isolez une voix du bruit de fond, de la musique ou d'autres sons : + +```python +from vllm_mlx.audio import AudioProcessor + +# Load SAM-Audio model +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() + +# Separate speech from audio +result = processor.separate("meeting_with_music.mp3", description="speech") + +# Save isolated voice and background +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background_only.wav") +``` + +**Exemple en ligne de commande :** +```bash +python examples/audio_separation_example.py meeting.mp3 --play +python examples/audio_separation_example.py song.mp3 --description music -o music.wav +``` + +### Démo de séparation de batterie + +Isolez la batterie d'une chanson rock avec SAM-Audio : + +| Audio | Description | Écouter | +|-------|-------------|---------| +| Original | "Get Ready" de David Fesliyan (30s, libre de droits) | [rock_get_ready.mp3](../../../examples/rock_get_ready.mp3) | +| Batterie isolée | Batterie extraite par SAM-Audio | [drums_isolated.wav](../../../examples/drums_isolated.wav) | +| Sans batterie | Piste sans batterie | [rock_no_drums.wav](../../../examples/rock_no_drums.wav) | + +```bash +# Isolate drums from rock song +python 
examples/audio_separation_example.py examples/rock_get_ready.mp3 \ + --description "drums" \ + --output drums_isolated.wav \ + --background rock_no_drums.wav +``` + +**Performance :** 30 secondes d'audio traitées en environ 20 secondes sur M4 Max. + +## Modèles pris en charge + +### Modèles STT (Speech-to-Text) + +| Modèle | Alias | Langues | Vitesse | Qualité | +|--------|-------|---------|---------|---------| +| `mlx-community/whisper-large-v3-mlx` | `whisper-large-v3` | 99+ | Moyenne | Meilleure | +| `mlx-community/whisper-large-v3-turbo` | `whisper-large-v3-turbo` | 99+ | Rapide | Excellente | +| `mlx-community/whisper-medium-mlx` | `whisper-medium` | 99+ | Rapide | Bonne | +| `mlx-community/whisper-small-mlx` | `whisper-small` | 99+ | Très rapide | Correcte | +| `mlx-community/parakeet-tdt-0.6b-v2` | `parakeet` | Anglais | La plus rapide | Excellente | +| `mlx-community/parakeet-tdt-0.6b-v3` | `parakeet-v3` | Anglais | La plus rapide | Meilleure | + +**Recommandations :** +- Multilingue : `whisper-large-v3` +- Anglais uniquement : `parakeet` (3 fois plus rapide) + +### Modèles TTS (Text-to-Speech) + +#### Kokoro (Rapide, Léger) - Recommandé + +| Modèle | Alias | Taille | Langues | +|--------|-------|--------|---------| +| `mlx-community/Kokoro-82M-bf16` | `kokoro` | 82M | EN, ES, FR, JA, ZH, HI, IT, PT | +| `mlx-community/Kokoro-82M-4bit` | `kokoro-4bit` | 82M | EN, ES, FR, JA, ZH, HI, IT, PT | + +**Voix (11) :** +- Femme américaine : `af_heart`, `af_bella`, `af_nicole`, `af_sarah`, `af_sky` +- Homme américain : `am_adam`, `am_michael` +- Femme britannique : `bf_emma`, `bf_isabella` +- Homme britannique : `bm_george`, `bm_lewis` + +**Codes de langue :** +| Code | Langue | Code | Langue | +|------|--------|------|--------| +| `a` / `en` | English (US) | `e` / `es` | Español | +| `b` / `en-gb` | English (UK) | `f` / `fr` | Français | +| `j` / `ja` | 日本語 | `z` / `zh` | 中文 | +| `i` / `it` | Italiano | `p` / `pt` | Português | +| `h` / `hi` | हिन्दी | | | + +#### Chatterbox (Multilingue, Expressif) + +| Modèle | Alias | Taille | Langues | +|--------|-------|--------|---------| +| `mlx-community/chatterbox-turbo-fp16` | `chatterbox` | 134M | 15+ langues | +| `mlx-community/chatterbox-turbo-4bit` | `chatterbox-4bit` | 134M | 15+ langues | + +**Langues prises en charge :** EN, ES, FR, DE, IT, PT, RU, JA, ZH, KO, AR, HI, NL, PL, TR + +#### VibeVoice (Temps réel) + +| Modèle | Alias | Taille | Cas d'usage | +|--------|-------|--------|-------------| +| `mlx-community/VibeVoice-Realtime-0.5B-4bit` | `vibevoice` | 200M | Faible latence, anglais | + +#### VoxCPM (Chinois/Anglais) + +| Modèle | Alias | Taille | Langues | +|--------|-------|--------|---------| +| `mlx-community/VoxCPM1.5` | `voxcpm` | 0.9B | ZH, EN | +| `mlx-community/VoxCPM1.5-4bit` | `voxcpm-4bit` | 200M | ZH, EN | + +### Modèles de traitement audio + +#### SAM-Audio (Séparation vocale) + +| Modèle | Taille | Cas d'usage | +|--------|--------|-------------| +| `mlx-community/sam-audio-large-fp16` | 3B | Meilleure qualité | +| `mlx-community/sam-audio-large` | 3B | Standard | +| `mlx-community/sam-audio-small-fp16` | 0.6B | Rapide | +| `mlx-community/sam-audio-small` | 0.6B | Léger | + +## Référence API + +### POST /v1/audio/transcriptions + +Transcrit un fichier audio en texte (compatible API OpenAI Whisper). 
+ +**Paramètres :** +- `file` : Fichier audio (mp3, wav, m4a, webm) +- `model` : Nom ou alias du modèle +- `language` : Code de langue (optionnel, détection automatique) +- `response_format` : `json` ou `text` + +**Limites :** +- Taille maximale par défaut : 25 MiB +- Modifiable avec `--max-audio-upload-mb` + +**Exemple :** +```bash +curl http://localhost:8000/v1/audio/transcriptions \ + -F file=@audio.mp3 \ + -F model=whisper-large-v3 +``` + +### POST /v1/audio/speech + +Génère de la parole à partir de texte (compatible API OpenAI TTS). + +**Paramètres :** +- `model` : Nom ou alias du modèle +- `input` : Texte à synthétiser +- `voice` : Identifiant de la voix +- `speed` : Vitesse de parole (0,5 à 2,0) +- `response_format` : `wav`, `mp3` + +**Limites :** +- Nombre de caractères maximal par défaut : 4096 +- Modifiable avec `--max-tts-input-chars` + +**Exemple :** +```bash +curl http://localhost:8000/v1/audio/speech \ + -d '{"model": "kokoro", "input": "Hello world", "voice": "af_heart"}' \ + -H "Content-Type: application/json" \ + --output speech.wav +``` + +### GET /v1/audio/voices + +Liste les voix disponibles pour un modèle. + +**Exemple :** +```bash +curl http://localhost:8000/v1/audio/voices?model=kokoro +``` + +## Exemples en ligne de commande + +### Transcription en direct / Sous-titres en temps réel + +Transcription STT en temps réel depuis votre microphone : + +```bash +# Closed captions with whisper-large-v3 (best quality) +python examples/closed_captions.py --language es --chunk 5 + +# Faster model for lower latency +python examples/closed_captions.py --language en --model whisper-turbo --chunk 3 + +# Basic mic transcription (record then transcribe) +python examples/mic_transcribe.py --language es + +# Real-time chunked transcription +python examples/mic_realtime.py --language es --chunk 3 + +# Live transcription with voice activity detection +python examples/mic_live.py --language es +``` + +**Prérequis :** +```bash +pip install sounddevice soundfile numpy +``` + +### TTS de base + +```bash +# Simple TTS example +python examples/tts_example.py "Hello, how are you?" --play + +# With different voice +python examples/tts_example.py "Hello!" --voice am_michael --play + +# Save to file +python examples/tts_example.py "Welcome to the demo" -o greeting.wav + +# List available voices +python examples/tts_example.py --list-voices +``` + +### TTS multilingue + +```bash +# English (auto-selects best model) +python examples/tts_multilingual.py "Hello world" --play + +# Spanish +python examples/tts_multilingual.py "Hola mundo" --lang es --play + +# French +python examples/tts_multilingual.py "Bonjour le monde" --lang fr --play + +# Japanese +python examples/tts_multilingual.py "こんにちは" --lang ja --play + +# Chinese +python examples/tts_multilingual.py "你好世界" --lang zh --play + +# Use specific model +python examples/tts_multilingual.py "Hello" --model chatterbox --play + +# List all models +python examples/tts_multilingual.py --list-models + +# List all languages +python examples/tts_multilingual.py --list-languages +``` + +### Exemples d'assistant vocal professionnel + +Exemples vocaux prégénérés avec des **voix natives** pour des cas d'usage professionnels courants : + +| Langue | Voix | Message | Écouter | +|--------|------|---------|---------| +| Anglais | af_heart | "Welcome to First National Bank. How may I assist you today?" | [assistant_bank_en.wav](../../../examples/assistant_bank_en.wav) | +| Espagnol | ef_dora | "Gracias por llamar a servicio al cliente. 
Un agente le atenderá pronto." | [assistant_service_es.wav](../../../examples/assistant_service_es.wav) | +| Français | ff_siwis | "Bienvenue. Votre appel est important pour nous." | [assistant_callcenter_fr.wav](../../../examples/assistant_callcenter_fr.wav) | +| Chinois | zf_xiaobei | "欢迎致电技术支持中心。我们将竭诚为您服务。" | [assistant_support_zh.wav](../../../examples/assistant_support_zh.wav) | + +**Générez vos propres exemples avec des voix natives :** +```bash +# English - Bank assistant (native voice: af_heart) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Welcome to First National Bank. How may I assist you today?" \ + --voice af_heart --lang_code a --file_prefix assistant_bank_en + +# Spanish - Customer service (native voice: ef_dora) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Gracias por llamar a servicio al cliente. Un agente le atendera pronto." \ + --voice ef_dora --lang_code e --file_prefix assistant_service_es + +# French - Call center (native voice: ff_siwis) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Bienvenue. Votre appel est important pour nous." \ + --voice ff_siwis --lang_code f --file_prefix assistant_callcenter_fr + +# Chinese - Tech support (native voice: zf_xiaobei) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "欢迎致电技术支持中心。我们将竭诚为您服务。" \ + --voice zf_xiaobei --lang_code z --file_prefix assistant_support_zh +``` + +### Référence des voix natives + +| Langue | Code | Voix | +|--------|------|------| +| English (US) | `a` | af_heart, af_bella, af_nicole, am_adam, am_michael | +| English (UK) | `b` | bf_emma, bf_isabella, bm_george, bm_lewis | +| Espagnol | `e` | ef_dora, em_alex, em_santa | +| Français | `f` | ff_siwis | +| Chinois | `z` | zf_xiaobei, zf_xiaoni, zf_xiaoxiao, zm_yunjian, zm_yunxi | +| Japonais | `j` | jf_alpha, jf_gongitsune, jm_kumo | +| Italien | `i` | if_sara, im_nicola | +| Portugais | `p` | pf_dora, pm_alex | +| Hindi | `h` | hf_alpha, hf_beta, hm_omega | + +## API Python + +### Utilisation directe (sans serveur) + +```python +from vllm_mlx.audio import STTEngine, TTSEngine, AudioProcessor + +# Speech-to-Text +stt = STTEngine("mlx-community/whisper-large-v3-mlx") +stt.load() +result = stt.transcribe("audio.mp3") +print(result.text) + +# Text-to-Speech +tts = TTSEngine("mlx-community/Kokoro-82M-bf16") +tts.load() +audio = tts.generate("Hello world", voice="af_heart") +tts.save(audio, "output.wav") + +# Voice Separation +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() +result = processor.separate("mixed_audio.mp3", description="speech") +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background.wav") +``` + +### Fonctions utilitaires + +```python +from vllm_mlx.audio import transcribe_audio, generate_speech, separate_voice + +# Quick transcription +result = transcribe_audio("audio.mp3") +print(result.text) + +# Quick TTS +audio = generate_speech("Hello world", voice="af_heart") + +# Quick voice separation +voice, background = separate_voice("mixed.mp3") +``` + +## Audio dans le chat + +Incluez de l'audio dans les messages du chat (transcrit automatiquement) : + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Summarize this audio"}, + {"type": "audio_url", "audio_url": {"url": "file://meeting.mp3"}} + ] + }] +) +``` + +## Benchmarks + +Testé 
sur Apple M2 Max (32 Go). + +### Benchmarks TTS (Kokoro-82M-bf16) + +| Longueur du texte | Durée audio | Temps de génération | RTF | Caractères/s | +|-------------------|-------------|---------------------|-----|--------------| +| 25 caractères | 1,95 s | 0,43 s | 4,6x | 58,5 | +| 88 caractères | 6,00 s | 0,32 s | 18,6x | 272,4 | +| 117 caractères | 7,92 s | 0,27 s | 29,0x | 427,4 | + +**Résumé :** +- Temps de chargement du modèle : environ 1,0 s +- RTF moyen : **17,4x** (17 fois plus rapide que le temps réel) +- Caractères/s moyens : **252,8** + +### Benchmarks STT + +| Modèle | Temps de chargement | Transcription (6 s audio) | RTF | +|--------|---------------------|---------------------------|-----| +| whisper-small | 0,25 s | 0,20 s | 30,2x | +| whisper-medium | 18,1 s | 0,38 s | 15,5x | +| whisper-large-v3 | environ 30 s | environ 0,6 s | environ 10x | +| parakeet | environ 0,5 s | environ 0,15 s | environ 40x | + +**Notes :** +- Le RTF (Real-Time Factor) indique combien de fois plus rapide que le temps réel +- Le premier chargement inclut le téléchargement du modèle depuis HuggingFace +- Les chargements suivants utilisent les modèles mis en cache + +### Recommandations par cas d'usage + +| Cas d'usage | Modèle recommandé | Pourquoi | +|-------------|------------------|----------| +| STT anglais rapide | `parakeet` | RTF 40x, faible consommation mémoire | +| STT multilingue | `whisper-large-v3` | 99+ langues | +| STT faible latence | `whisper-small` | RTF 30x, chargement rapide | +| TTS général | `kokoro` | RTF 17x, bonne qualité | +| TTS faible mémoire | `kokoro-4bit` | Quantification 4 bits | + +## Conseils de performance + +1. **Utilisez Parakeet pour l'anglais** : 40 fois plus rapide que le temps réel +2. **Utilisez les modèles 4 bits** pour réduire la consommation mémoire +3. **Utilisez SAM-Audio small** pour une séparation vocale plus rapide +4. **Mettez les modèles en cache** : les moteurs sont chargés à la demande et mis en cache +5. **Pré-téléchargez les modèles** pour éviter la latence au premier démarrage + +## Dépannage + +### mlx-audio non installé +``` +pip install mlx-audio>=0.2.9 +``` + +### Téléchargement du modèle lent +Les modèles sont téléchargés depuis HuggingFace lors de la première utilisation. Utilisez `huggingface-cli download` pour les pré-télécharger : +```bash +huggingface-cli download mlx-community/whisper-large-v3-mlx +huggingface-cli download mlx-community/Kokoro-82M-bf16 +``` + +### Mémoire insuffisante +Utilisez des modèles plus petits ou des versions quantifiées en 4 bits : +- `whisper-small-mlx` plutôt que `whisper-large-v3-mlx` +- `Kokoro-82M-4bit` plutôt que `Kokoro-82M-bf16` +- `sam-audio-small` plutôt que `sam-audio-large` + +### Bug multilingue Kokoro (mlx-audio 0.2.9) + +Si vous obtenez `ValueError: too many values to unpack` en utilisant des langues autres que l'anglais (espagnol, chinois, japonais, etc.) 
avec Kokoro, appliquez ce correctif : + +```python +# Fix for mlx_audio/tts/models/kokoro/pipeline.py line 443 +# Change: +# ps, _ = self.g2p(chunk) +# To: +g2p_result = self.g2p(chunk) +ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result +``` + +**Correctif en une ligne :** +```bash +python -c " +import os +path = os.path.join(os.path.dirname(__import__('mlx_audio').__file__), 'tts/models/kokoro/pipeline.py') +with open(path, 'r') as f: content = f.read() +old = ' ps, _ = self.g2p(chunk)' +new = ''' # Fix: handle both tuple (en) and string (zh/ja/es) returns from g2p + g2p_result = self.g2p(chunk) + ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result''' +if old in content: + with open(path, 'w') as f: f.write(content.replace(old, new)) + print('Fix applied!') +" +``` + +Ce bug survient car le g2p anglais retourne un tuple `(phonemes, tokens)` tandis que les autres langues retournent uniquement une chaîne de caractères. diff --git a/docs/fr/guides/continuous-batching.md b/docs/fr/guides/continuous-batching.md new file mode 100644 index 000000000..fc9982ba3 --- /dev/null +++ b/docs/fr/guides/continuous-batching.md @@ -0,0 +1,174 @@ +# Continuous Batching + +Le continuous batching permet d'augmenter le throughput lors du traitement de plusieurs utilisateurs simultanés. + +## Activer le Continuous Batching + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching +``` + +## Avec le Paged Cache + +Pour un partage mémoire efficace des préfixes : + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching --use-paged-cache +``` + +## Fonctionnement + +### Mode simple (par défaut) +- Une seule requête à la fois +- Throughput maximal pour un utilisateur unique +- Aucune surcharge liée au batching + +### Mode Continuous Batching +- Plusieurs requêtes traitées simultanément +- Meilleur throughput pour les utilisateurs concurrents +- Légère surcharge par requête + +### Paged Cache +- Le KV cache est stocké en blocs de taille fixe +- Les prompts système identiques partagent les mêmes blocs +- Économies mémoire : 80 % et plus pour 10 utilisateurs simultanés ou davantage + +## Résultats de performance + +**Résultats du Continuous Batching (M4 Max, 128 Go) :** + +| Modèle | Requête unique | Batch (5 req) | Accélération | +|--------|----------------|---------------|--------------| +| Llama-3.2-1B-Instruct-4bit | 299.1 tok/s | 613.0 tok/s | **2.05x** | +| Llama-3.2-3B-Instruct-4bit | 137.6 tok/s | 208.1 tok/s | **1.51x** | +| Qwen3-0.6B-8bit | 328.1 tok/s | 1111.8 tok/s | **3.39x** | +| Qwen3-30B-A3B-4bit | 98.1 tok/s | 233.3 tok/s | **2.38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196.9 tok/s | 322.2 tok/s | **1.64x** | + +*Le batching de 5 requêtes simultanées améliore le throughput d'un facteur 1,5 à 3.* + +## Performance en Streaming + +**Performance en streaming (M4 Max, 128 Go) :** + +| Modèle | TTFT | Vitesse de génération | +|--------|------|-----------------------| +| Llama-3.2-1B-Instruct-4bit | ~4.6 ms | 218.9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10.7 ms | 93.6 tok/s | +| Qwen3-0.6B-8bit | ~3.0 ms | 328.5 tok/s | +| Qwen3-30B-A3B-4bit | ~10.2 ms | 98.4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7.1 ms | 140.3 tok/s | + +*TTFT = Time to First Token* + +## Configuration du Streaming + +Contrôlez la cadence d'envoi des tokens avec `--stream-interval` : + +```bash +# Chaque token (le plus fluide) +vllm-mlx serve model --continuous-batching --stream-interval 1 + +# Tokens groupés (préférable pour les connexions à latence élevée) 
+vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +| Valeur | Comportement | +|--------|-------------| +| `1` | Envoie chaque token immédiatement | +| `2-5` | Regroupe les tokens avant l'envoi | +| `10+` | Throughput maximal, sortie plus fragmentée | + +## Gestion de la mémoire + +Pour les grands modèles, le prefix cache peut consommer une quantité significative de mémoire. Le cache adaptatif la gère automatiquement : + +```bash +# Détection automatique (utilise 20 % de la RAM disponible) +vllm-mlx serve model --continuous-batching + +# Limite explicite +vllm-mlx serve model --continuous-batching --cache-memory-mb 2048 + +# Pourcentage personnalisé +vllm-mlx serve model --continuous-batching --cache-memory-percent 0.10 +``` + +| Option | Description | +|--------|-------------| +| `--cache-memory-mb` | Définit une limite explicite en Mo | +| `--cache-memory-percent` | Fraction de la RAM disponible (par défaut : 0,20) | +| `--no-memory-aware-cache` | Utilise le cache historique basé sur le nombre d'entrées | + +## Prefix Cache + +Le prefix caching réutilise le KV cache pour les prompts répétés. + +### Fonctionnement + +``` +User 1: System prompt (500 tokens) → Creates 8 blocks +User 2: Same system prompt → Shares 8 blocks (ref_count++) +User N: Same system prompt → Shares 8 blocks (ref_count++) + +Memory savings: 80%+ for 10+ concurrent users +``` + +### Stratégie de clé de cache + +- **LLM** : `hash(prompt)` +- **Images** : `hash(image_content) + hash(prompt)` +- **Vidéos** : `hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +### Tester le Prefix Cache + +```bash +python tests/test_prefix_cache.py +``` + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS or PREFIX_HIT (shared template tokens) +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | PASS + 1b | Same prompt | HIT | HIT | PASS + 1c | Different prompt | MISS | MISS | PASS + 1d | Return to prompt 1 | HIT | HIT | PASS +====================================================================== +``` + +## Exécuter les benchmarks + +```bash +# Benchmark du continuous batching +python tests/test_continuous_batching.py + +# Test du prefix cache +python tests/test_prefix_cache.py +``` + +## Quand l'utiliser + +| Scénario | Mode | +|----------|------| +| Utilisateur unique, vitesse maximale | Simple (par défaut) | +| Plusieurs utilisateurs simultanés | `--continuous-batching` | +| Grands modèles (7B et plus) | `--continuous-batching --cache-memory-mb 2048` | +| Production avec prompts partagés | `--continuous-batching --use-paged-cache` | + +## Configuration en production + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --port 8000 +``` diff --git a/docs/fr/guides/embeddings.md b/docs/fr/guides/embeddings.md new file mode 100644 index 000000000..3b22f6bbc --- /dev/null +++ b/docs/fr/guides/embeddings.md @@ -0,0 +1,150 @@ +# Embeddings + +vllm-mlx prend en charge les embeddings de texte via [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings), en exposant un point d'accès `/v1/embeddings` compatible OpenAI. 
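+
+Les vecteurs renvoyés s'utilisent directement pour la recherche sémantique. L'esquisse ci-dessous (en supposant un serveur déjà lancé sur le port 8000, voir les sections suivantes) calcule une similarité cosinus entre deux phrases avec le SDK OpenAI et NumPy :
+
+```python
+import numpy as np
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
+
+resp = client.embeddings.create(
+    model="mlx-community/all-MiniLM-L6-v2-4bit",
+    input=["Le chat dort sur le canapé", "Un félin fait la sieste"],
+)
+a, b = (np.array(d.embedding) for d in resp.data)
+
+# Similarité cosinus : proche de 1.0 pour des phrases sémantiquement proches
+similarite = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
+print(f"Similarité cosinus : {similarite:.3f}")
+```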
+ +## Installation + +```bash +pip install mlx-embeddings>=0.0.5 +``` + +## Démarrage rapide + +### Lancer le serveur avec un modèle d'embeddings + +```bash +# Précharger un modèle d'embeddings spécifique au démarrage +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +Si vous n'utilisez pas `--embedding-model`, le modèle d'embeddings est chargé à la demande lors de la première requête, mais uniquement parmi les modèles autorisés par défaut. + +### Générer des embeddings avec le SDK OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Texte unique +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions + +# Lot de textes +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input=[ + "I love machine learning", + "Deep learning is fascinating", + "Natural language processing rocks" + ] +) +for item in response.data: + print(f"Text {item.index}: {len(item.embedding)} dimensions") +``` + +### Utilisation avec curl + +```bash +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "input": ["Hello world", "How are you?"] + }' +``` + +## Modèles pris en charge + +Modèles disponibles à la demande : + +| Modèle | Cas d'usage | Taille | +|--------|-------------|--------| +| `mlx-community/all-MiniLM-L6-v2-4bit` | Rapide et compact | Small | +| `mlx-community/embeddinggemma-300m-6bit` | Haute qualité | 300M | +| `mlx-community/bge-large-en-v1.5-4bit` | Optimal pour l'anglais | Large | +| `mlx-community/multilingual-e5-small-mlx` | Récupération multilingue | Small | +| `mlx-community/multilingual-e5-large-mlx` | Récupération multilingue | Large | +| `mlx-community/bert-base-uncased-mlx` | Référence BERT générale | Base | +| `mlx-community/ModernBERT-base-mlx` | Référence ModernBERT | Base | + +Les autres modèles d'embeddings nécessitent l'option `--embedding-model` au démarrage du serveur. + +## Gestion des modèles + +### Chargement à la demande + +Par défaut, le modèle d'embeddings est chargé lors de la première requête sur `/v1/embeddings`. Vous pouvez alterner entre les modèles autorisés listés ci-dessus ; le modèle précédent est déchargé automatiquement. + +### Préchargement au démarrage + +Utilisez `--embedding-model` pour charger un modèle au démarrage. Lorsque cette option est définie, seul ce modèle peut être utilisé pour les embeddings : + +```bash +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +Toute requête utilisant un modèle différent renverra une erreur 400. + +## Référence de l'API + +### POST /v1/embeddings + +Génère des embeddings pour le ou les textes fournis. 
+ +**Corps de la requête :** + +| Champ | Type | Requis | Description | +|-------|------|--------|-------------| +| `model` | string | Oui | Identifiant d'un modèle d'embeddings pris en charge, ou le modèle fixé au démarrage si `--embedding-model` est utilisé | +| `input` | string ou list[string] | Oui | Texte(s) à encoder | + +**Réponse :** + +```json +{ + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": [0.023, -0.982, ...]}, + {"object": "embedding", "index": 1, "embedding": [0.112, -0.543, ...]} + ], + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "usage": {"prompt_tokens": 12, "total_tokens": 12} +} +``` + +## API Python + +### Utilisation directe sans serveur + +```python +from vllm_mlx.embedding import EmbeddingEngine + +engine = EmbeddingEngine("mlx-community/all-MiniLM-L6-v2-4bit") +engine.load() + +vectors = engine.embed(["Hello world", "How are you?"]) +print(f"Dimensions: {len(vectors[0])}") + +tokens = engine.count_tokens(["Hello world"]) +print(f"Token count: {tokens}") +``` + +## Résolution des problèmes + +### mlx-embeddings non installé + +``` +pip install mlx-embeddings>=0.0.5 +``` + +### Modèle introuvable + +Vérifiez que le nom du modèle correspond à l'un des identifiants autorisés listés ci-dessus, ou lancez le serveur avec `--embedding-model` pour fixer un modèle personnalisé. Vous pouvez télécharger les modèles pris en charge à l'avance : + +```bash +huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit +``` diff --git a/docs/fr/guides/mcp-tools.md b/docs/fr/guides/mcp-tools.md new file mode 100644 index 000000000..77b97a434 --- /dev/null +++ b/docs/fr/guides/mcp-tools.md @@ -0,0 +1,405 @@ +# MCP & Tool Calling + +vllm-mlx prend en charge le Model Context Protocol (MCP) pour intégrer des outils externes avec des LLM. + +## Fonctionnement du tool calling + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Tool Calling Flow │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. User Request │ +│ ─────────────────► "List files in /tmp" │ +│ │ +│ 2. LLM Generates Tool Call │ +│ ─────────────────► tool_calls: [{ │ +│ name: "list_directory", │ +│ arguments: {path: "/tmp"} │ +│ }] │ +│ │ +│ 3. App Executes Tool via MCP │ +│ ─────────────────► MCP Server executes list_directory │ +│ Returns: ["file1.txt", "file2.txt"] │ +│ │ +│ 4. Tool Result Sent Back to LLM │ +│ ─────────────────► role: "tool", content: [...] │ +│ │ +│ 5. LLM Generates Final Response │ +│ ─────────────────► "The /tmp directory contains 2 files..." │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +## Démarrage rapide + +### 1. Créer la configuration MCP + +Créez `mcp.json` : + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### 2. Démarrer le serveur avec MCP + +```bash +# Mode simple +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json + +# Continuous batching +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json --continuous-batching +``` + +### 3. Vérifier l'état du MCP + +```bash +# Vérifier l'état du MCP +curl http://localhost:8000/v1/mcp/status + +# Lister les outils disponibles +curl http://localhost:8000/v1/mcp/tools +``` + +## Exemple de tool calling + +```python +import json +import httpx + +BASE_URL = "http://localhost:8000" + +# 1. 
Get available tools +tools_response = httpx.get(f"{BASE_URL}/v1/mcp/tools") +tools = tools_response.json()["tools"] + +# 2. Send request with tools +response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "model": "default", + "messages": [{"role": "user", "content": "List files in /tmp"}], + "tools": tools, + "max_tokens": 1024 + } +) + +result = response.json() +message = result["choices"][0]["message"] + +# 3. Check for tool calls +if message.get("tool_calls"): + tool_call = message["tool_calls"][0] + + # 4. Execute tool via MCP + exec_response = httpx.post( + f"{BASE_URL}/v1/mcp/execute", + json={ + "server": "filesystem", + "tool": tool_call["function"]["name"], + "arguments": json.loads(tool_call["function"]["arguments"]) + } + ) + tool_result = exec_response.json() + + # 5. Send result back to LLM + messages = [ + {"role": "user", "content": "List files in /tmp"}, + message, + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": json.dumps(tool_result["result"]) + } + ] + + final_response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={"model": "default", "messages": messages} + ) + print(final_response.json()["choices"][0]["message"]["content"]) +``` + +## Points de terminaison MCP + +| Point de terminaison | Méthode | Description | +|----------------------|---------|-------------| +| `/v1/mcp/status` | GET | Vérifier l'état du MCP | +| `/v1/mcp/tools` | GET | Lister les outils disponibles | +| `/v1/mcp/execute` | POST | Exécuter un outil | + +## Exemples de serveurs MCP + +### Système de fichiers + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### GitHub + +```json +{ + "mcpServers": { + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +### PostgreSQL + +```json +{ + "mcpServers": { + "postgres": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-postgres"], + "env": { + "DATABASE_URL": "postgresql://user:pass@localhost/db" + } + } + } +} +``` + +### Brave Search + +```json +{ + "mcpServers": { + "brave-search": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-brave-search"], + "env": { + "BRAVE_API_KEY": "your-key" + } + } + } +} +``` + +## Plusieurs serveurs MCP + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +## Chat MCP interactif + +Pour tester MCP de manière interactive : + +```bash +python examples/mcp_chat.py +``` + +## Formats d'outils pris en charge + +vllm-mlx prend en charge 12 tool call parsers couvrant toutes les grandes familles de modèles. Voir [Tool Calling](tool-calling.md) pour la liste complète des parsers, alias et exemples. + +## Sécurité + +vllm-mlx inclut des mesures de sécurité pour prévenir les attaques par injection de commandes via les serveurs MCP. 
+ +### Liste blanche de commandes + +Seules les commandes de confiance sont autorisées par défaut : + +| Catégorie | Commandes autorisées | +|-----------|----------------------| +| Node.js | `npx`, `npm`, `node` | +| Python | `uvx`, `uv`, `python`, `python3`, `pip`, `pipx` | +| Docker | `docker` | +| Serveurs MCP | `mcp-server-*` (serveurs officiels) | + +### Motifs bloqués + +Les motifs suivants sont bloqués pour prévenir les attaques par injection : + +- Enchaînement de commandes : `;`, `&&`, `||`, `|` +- Substitution de commandes : `` ` ``, `$()` +- Traversée de répertoires : `../` +- Variables d'environnement dangereuses : `LD_PRELOAD`, `PATH`, `PYTHONPATH` + +### Exemple : attaque bloquée + +```json +{ + "mcpServers": { + "malicious": { + "command": "bash", + "args": ["-c", "rm -rf /"] + } + } +} +``` + +Cette configuration sera rejetée : +``` +ValueError: MCP server 'malicious': Command 'bash' is not in the allowed commands whitelist. +``` + +### Mode développement (non sécurisé) + +Pour le développement uniquement, il est possible de contourner la validation de sécurité : + +```json +{ + "mcpServers": { + "custom": { + "command": "my-custom-server", + "skip_security_validation": true + } + } +} +``` + +**AVERTISSEMENT** : N'utilisez jamais `skip_security_validation` en production ! + +### Liste blanche personnalisée + +Pour ajouter des commandes personnalisées à la liste blanche par programmation : + +```python +from vllm_mlx.mcp import MCPCommandValidator, set_validator + +# Add custom commands +validator = MCPCommandValidator( + custom_whitelist={"my-trusted-server", "another-server"} +) +set_validator(validator) +``` + +## Sandboxing de l'exécution des outils + +Au-delà de la validation des commandes, vllm-mlx fournit un sandboxing à l'exécution pour les exécutions d'outils. 
+ +### Fonctionnalités du sandbox + +| Fonctionnalité | Description | +|----------------|-------------| +| Liste blanche d'outils | Autoriser uniquement des outils spécifiques à s'exécuter | +| Liste noire d'outils | Bloquer des outils dangereux spécifiques | +| Validation des arguments | Bloquer les motifs dangereux dans les arguments des outils | +| Limitation du débit | Limiter les appels d'outils par minute | +| Journal d'audit | Suivre toutes les exécutions d'outils | + +### Motifs d'arguments bloqués + +Les arguments des outils sont validés pour détecter les motifs dangereux : + +- Traversée de répertoires : `../` +- Répertoires système : `/etc/`, `/proc/`, `/sys/` +- Accès root : `/root/`, `~root` + +### Détection des outils à haut risque + +Les outils correspondant à ces motifs déclenchent des avertissements de sécurité : + +- `execute`, `run_command`, `shell`, `eval`, `exec`, `system`, `subprocess` + +### Configuration personnalisée du sandbox + +```python +from vllm_mlx.mcp import ToolSandbox, set_sandbox + +# Create sandbox with custom settings +sandbox = ToolSandbox( + # Only allow specific tools (whitelist mode) + allowed_tools={"read_file", "list_directory"}, + + # Block specific tools (blacklist mode) + blocked_tools={"execute_command", "run_shell"}, + + # Rate limit: max 30 calls per minute + max_calls_per_minute=30, + + # Optional audit callback + audit_callback=lambda audit: print(f"Tool: {audit.tool_name}, Success: {audit.success}"), +) +set_sandbox(sandbox) +``` + +### Accès aux journaux d'audit + +```python +from vllm_mlx.mcp import get_sandbox + +sandbox = get_sandbox() + +# Get recent audit entries +entries = sandbox.get_audit_log(limit=50) + +# Filter by tool name +file_ops = sandbox.get_audit_log(tool_filter="file") + +# Get only errors +errors = sandbox.get_audit_log(errors_only=True) + +# Clear audit log +sandbox.clear_audit_log() +``` + +### Expurgation des données sensibles + +Les journaux d'audit expurgent automatiquement les champs sensibles (password, token, secret, key, credential, auth) et tronquent les valeurs volumineuses. + +## Dépannage + +### Le serveur MCP ne se connecte pas + +Vérifiez que la commande du serveur MCP est correcte : +```bash +npx -y @modelcontextprotocol/server-filesystem /tmp +``` + +### L'outil ne s'exécute pas + +Vérifiez que l'outil est disponible : +```bash +curl http://localhost:8000/v1/mcp/tools | jq '.tools[].name' +``` + +### L'appel d'outil n'est pas analysé + +Assurez-vous d'utiliser un modèle qui prend en charge le function calling (Qwen3, Llama-3.2-Instruct). + +### La commande n'est pas dans la liste blanche + +Si vous voyez « Command X is not in the allowed commands whitelist », vous pouvez : +1. Utiliser une commande autorisée (voir la liste blanche ci-dessus) +2. Ajouter la commande à une liste blanche personnalisée +3. Utiliser `skip_security_validation: true` (développement uniquement) diff --git a/docs/fr/guides/moe-top-k.md b/docs/fr/guides/moe-top-k.md new file mode 100644 index 000000000..310c9872a --- /dev/null +++ b/docs/fr/guides/moe-top-k.md @@ -0,0 +1,124 @@ +# MoE top_k override (`--moe-top-k`) + +Réduit le nombre d'experts activés par token dans les modèles Mixture of Experts +comme Qwen3-30B-A3B, en échangeant une légère perte de qualité contre un gain +sensible de débit au décodage. + +> **Statut :** option à activer explicitement. Le comportement par défaut est inchangé. 
Les chiffres de qualité +> ci-dessous concernent Qwen3-30B-A3B-4bit sur M4 Max 128 Go ; vérifiez sur votre modèle +> avant de déployer en production. + +## Ce que ça fait + +Qwen3-30B-A3B est entraîné avec `top_k=8` : chaque token sélectionne 8 experts parmi 128. +Sur Apple Silicon en décodage batch=1, le produit matriciel des experts (`SwitchGLU`) +représente la plus grande part du calcul de chaque couche, et ce coût évolue +approximativement de façon linéaire avec `top_k`. Abaisser `top_k` à l'inférence a +été démontré (LExI 2025, Lynx 2024) comme préservant l'essentiel de la qualité +entraînée tout en réduisant significativement le temps de décodage. + +`--moe-top-k N` parcourt toutes les couches du modèle chargé et, pour chaque couche +qui possède `.mlp.switch_mlp` (c'est-à-dire un bloc sparse-MoE), définit `top_k = N`. +Les couches denses et les modèles denses ne sont pas modifiés ; le flag n'a aucun effet pour eux. + +## Utilisation + +```bash +# Server +vllm-mlx serve mlx-community/Qwen3-30B-A3B-4bit \ + --continuous-batching \ + --moe-top-k 4 + +# Bench +vllm-mlx bench mlx-community/Qwen3-30B-A3B-4bit --moe-top-k 4 +``` + +Le flag est rejeté si `N` est supérieur au `top_k` d'entraînement du modèle +(il ne peut que diminuer, jamais augmenter). + +## Impact mesuré + +### Débit de décodage (M4 Max 128 Go, batch=1, greedy) + +| top_k | tok/s | vs baseline | +|---:|---:|---:| +| 8 (baseline) | 126.5 | - | +| 6 | 136.1 | +7.6% | +| 5 | 140.3 | +10.9% | +| 4 | 147.3 | +16.5% | + +### Qualité (Qwen3-30B-A3B-4bit, lm-evaluation-harness, MLX backend) + + + +| top_k | MMLU (acc) | GSM8K (exact match) | Delta vs baseline | +|---:|---:|---:|---:| +| 8 | TBD | TBD | - | +| 6 | TBD | TBD | TBD | +| 5 | TBD | TBD | TBD | +| 4 | TBD | TBD | TBD | + +MMLU : 200 échantillons sélectionnés aléatoirement, 0-shot. +GSM8K : 100 échantillons sélectionnés aléatoirement, 0-shot, exact-match strict. + +Ces chiffres sont **indicatifs** ; les suites complètes sont plus grandes et +feraient légèrement varier la précision absolue, mais pas le delta relatif entre +configurations. + +### Parité des sorties greedy + +Avec `top_k=4` sur le checkpoint 4-bit, nous avons observé des **16 premiers tokens +générés identiques** par rapport à la baseline sur toutes les requêtes de test. +Cela suggère que top_k=4 ne modifie pas l'argmax dans les premières étapes de +décodage : le modèle est intrinsèquement robuste à la suppression de la moitié +de ses experts activés. + +À `top_k=3` ou moins, la qualité commencerait à se dégrader visiblement (non mesuré +ici ; déduit du papier LExI). Le flag ne peut donc pas descendre en dessous de 1 +au niveau de la validation de configuration, mais le seuil recommandé pour la +production est `top_k=4`. + +## Quand l'utiliser, quand ne pas l'utiliser + +Utilisez-le quand : +- Vous faites tourner un MoE Qwen3 (ou compatible : Qwen3.5 MoE, Gemma-MoE) et le + débit de décodage en usage single-user est votre goulot d'étranglement. +- Votre cas d'usage tolère une légère dégradation de qualité en échange d'une + amélioration visible de la latence. +- Vous déployez sur du matériel limité par la bande passante mémoire (Apple Silicon + série M) où le gather des experts domine le temps de décodage par étape. + +Ne l'utilisez pas quand : +- Vous servez des modèles denses : le flag n'a aucun effet. +- La précision maximale sur les suites d'évaluation est une exigence. 
+- Vous exécutez des générations longues en chaîne de pensée (mode "thinking") où + la chute de qualité peut être plus prononcée que ce que suggèrent les scores MMLU 0-shot. + +## Combinaison avec d'autres optimisations + +Ce flag se compose avec la quantification. Sur Qwen3-30B-A3B-4bit, nos mesures +de combinaison sont : + +- 4-bit + top_k=8 : 126.5 tok/s (baseline) +- 4-bit + top_k=4 : 147.3 tok/s (+16.5%) +- 3-bit + top_k=8 : 138.6 tok/s (+9.6%) +- 3-bit + top_k=6 : 147.1 tok/s (+16.3%) . divergence de qualité mesurable +- 3-bit + top_k=4 : 157.3 tok/s (+24%) . **la qualité des sorties s'effondre** (le modèle a répondu à une question différente lors de notre test de fumée) + +3-bit + top_k=4 cumule l'erreur numérique au point où l'argmax n'est plus stable. +Limitez-vous à un seul réglage agressif à la fois : soit 4-bit + top_k=4, soit +3-bit + top_k=6. Les deux donnent approximativement le même tok/s (environ 147) +avec des profils de qualité très différents. + +## Fonctionnement interne + +- Fonction de patch : `vllm_mlx.scheduler.apply_moe_top_k_override(model, k)` +- Appliquée dans `Scheduler.__init__` après le chargement du modèle. +- Tests : `tests/test_moe_top_k.py`. couvre les modèles denses, les architectures + mixtes et les chemins de validation. + +## Références + +- LExI : Layer-Adaptive Active Experts, [arXiv 2509.02753](https://arxiv.org/html/2509.02753) +- Not All Experts are Equal (NAEE), [ACL 2024](https://aclanthology.org/2024.acl-long.334.pdf) +- SwiftLM (`SWIFTLM_TOP_K` env knob prior art), [github.com/SharpAI/SwiftLM](https://github.com/SharpAI/SwiftLM) diff --git a/docs/fr/guides/multimodal.md b/docs/fr/guides/multimodal.md new file mode 100644 index 000000000..a8b168a7e --- /dev/null +++ b/docs/fr/guides/multimodal.md @@ -0,0 +1,315 @@ +# Modèles multimodaux (images et vidéos) + +vllm-mlx prend en charge les VLM pour la compréhension des images et des vidéos. + +## Modèles pris en charge + +- Qwen3-VL (recommandé) +- Qwen2-VL +- Gemma 3 +- LLaVA +- Idefics +- PaliGemma +- Pixtral +- Molmo +- DeepSeek-VL + +## Démarrer un serveur multimodal + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +Les modèles dont le nom contient « VL », « Vision » ou « mllm » sont automatiquement détectés comme multimodaux. 
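
À titre d'illustration, l'heuristique décrite ci-dessus peut se résumer à une simple recherche de sous-chaînes dans le nom du modèle. L'esquisse suivante est hypothétique (la fonction `looks_multimodal` n'existe pas dans vllm-mlx) et la détection réelle peut être plus fine :

```python
# Esquisse illustrative de l'heuristique documentée ci-dessus :
# un nom contenant « VL », « Vision » ou « mllm » est traité comme multimodal.
# Fonction hypothétique ; la détection réelle de vllm-mlx peut différer.
MULTIMODAL_MARKERS = ("vl", "vision", "mllm")

def looks_multimodal(model_name: str) -> bool:
    name = model_name.lower()
    return any(marker in name for marker in MULTIMODAL_MARKERS)

print(looks_multimodal("mlx-community/Qwen3-VL-4B-Instruct-3bit"))   # True
print(looks_multimodal("mlx-community/Llama-3.2-3B-Instruct-4bit"))  # False
```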
+ +## Analyse d'images + +### Via le SDK OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Image depuis une URL +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +print(response.choices[0].message.content) +``` + +### Images en Base64 + +```python +import base64 + +def encode_image(path): + with open(path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + +base64_image = encode_image("photo.jpg") +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + ] + }] +) +``` + +### Via curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + "max_tokens": 256 + }' +``` + +## Analyse de vidéos + +### Via le SDK OpenAI + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What happens in this video?"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + max_tokens=512 +) +``` + +### Paramètres vidéo + +Contrôlez l'extraction des images via les paramètres du corps étendu : + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "video.mp4"}} + ] + }], + extra_body={ + "video_fps": 2.0, + "video_max_frames": 32 + } +) +``` + +### Via curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + "video_fps": 2.0, + "video_max_frames": 16 + }' +``` + +## Formats pris en charge + +### Images + +| Format | Exemple | +|--------|---------| +| URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| Fichier local | `{"type": "image_url", "image_url": {"url": "/path/to/image.jpg"}}` | +| Base64 | `{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}` | + +### Vidéos + +| Format | Exemple | +|--------|---------| +| URL | `{"type": "video_url", "video_url": {"url": "https://..."}}` | +| Fichier local | `{"type": "video", "video": "/path/to/video.mp4"}` | +| Base64 | `{"type": "video_url", "video_url": {"url": "data:video/mp4;base64,..."}}` | + +## API Python + +```python +from vllm_mlx.models import MLXMultimodalLM + +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Image +description = mllm.describe_image("photo.jpg") + +# Vidéo +description = mllm.describe_video("video.mp4", fps=2.0) + +# Prompt personnalisé +output = mllm.generate( + prompt="Compare these images", + images=["img1.jpg", "img2.jpg"] +) +``` + +## 
Conseils de performance + +### Images +- Les résolutions plus petites sont traitées plus rapidement (224x224 vs 1920x1080) +- Utilisez la résolution adaptée à votre tâche + +### Vidéos +- Un FPS plus bas accélère le traitement +- Moins d'images signifie moins de mémoire utilisée +- 64 images est le maximum pratique (96 et plus provoque un timeout GPU) + +## Benchmarks + +Testés sur Apple M4 Max avec 128 Go de mémoire unifiée. + +### Qwen3-VL-4B-Instruct-3bit + +| Résolution | Temps | Tokens | Vitesse | Mémoire | +|------------|-------|--------|---------|---------| +| 224x224 | 0.87s | 124 | 143 tok/s | 2.6 Go | +| 448x448 | 1.01s | 107 | 106 tok/s | 3.1 Go | +| 768x768 | 1.42s | 127 | 89 tok/s | 3.4 Go | +| 1024x1024 | 1.85s | 116 | 63 tok/s | 3.6 Go | + +### Qwen3-VL-8B-Instruct-4bit + +| Résolution | Temps | Tokens | Vitesse | Mémoire | +|------------|-------|--------|---------|---------| +| 224x224 | 1.08s | 78 | 73 tok/s | 5.6 Go | +| 448x448 | 1.41s | 70 | 50 tok/s | 6.1 Go | +| 768x768 | 2.06s | 91 | 44 tok/s | 6.5 Go | +| 1024x1024 | 3.02s | 76 | 25 tok/s | 7.6 Go | + +### Gemma 3 4B 4bit + +| Résolution | Temps | Tokens | Vitesse | Mémoire | +|------------|-------|--------|---------|---------| +| 224x224 | 0.95s | 30 | 32 tok/s | 5.2 Go | +| 448x448 | 0.99s | 34 | 34 tok/s | 5.2 Go | +| 768x768 | 0.99s | 32 | 32 tok/s | 5.2 Go | +| 1024x1024 | 0.95s | 28 | 29 tok/s | 5.2 Go | + +### Lancer les benchmarks + +```bash +# Benchmark rapide +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --quick + +# Benchmark complet avec plus de résolutions +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Benchmark vidéo +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --video +``` + +## Cache MLLM + +vllm-mlx inclut un système de prefix cache pour les modèles multimodaux, capable d'accélérer significativement les requêtes répétées utilisant les mêmes images. + +### Fonctionnement + +Lorsque vous envoyez une image au modèle, l'encodeur de vision la traite en embeddings. Ce traitement prend 1 à 2 secondes. Le cache MLLM stocke ces embeddings ainsi que l'état du KV cache, de sorte que les requêtes ultérieures avec la même image contournent entièrement l'encodeur de vision. + +Le cache utilise un hachage basé sur le contenu (similaire à LMCache) pour identifier les images identiques, quelle que soit leur forme de transmission (URL, base64 ou chemin de fichier). 
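
Le principe du hachage par contenu peut s'illustrer ainsi. Esquisse uniquement : l'implémentation réelle de `vllm_mlx.mllm_cache` peut différer, notamment pour les URL distantes, et la fonction `image_content_key` ci-dessous est hypothétique :

```python
import base64
import hashlib
from pathlib import Path

def image_content_key(image_ref: str) -> str:
    # Esquisse : la clé de cache dérive du contenu de l'image, donc une même
    # image produit la même clé qu'elle arrive en data URL base64 ou en chemin
    # local. (Une URL http(s) devrait d'abord être téléchargée.)
    if image_ref.startswith("data:"):
        payload = base64.b64decode(image_ref.split(",", 1)[1])
    else:
        payload = Path(image_ref).read_bytes()
    return hashlib.sha256(payload).hexdigest()

# Même image, deux formes de transmission, même clé de cache
local_key = image_content_key("photo.jpg")
data_url = "data:image/jpeg;base64," + base64.b64encode(Path("photo.jpg").read_bytes()).decode()
assert image_content_key(data_url) == local_key
```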
+ +### Activer le cache + +```bash +# Activer avec les paramètres par défaut (512 Mo maximum) +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --enable-mllm-cache + +# Avec une limite mémoire personnalisée +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit \ + --enable-mllm-cache \ + --mllm-cache-max-mb 1024 +``` + +### API Python + +```python +from vllm_mlx.mllm_cache import MLLMPrefixCacheManager + +# Créer le gestionnaire de cache +cache = MLLMPrefixCacheManager(max_memory_mb=512) + +# Stocker les embeddings et le KV cache après traitement +cache.store( + images=["photo.jpg"], + prompt="Describe this image", + vision_embeddings=embeddings, + kv_cache=kv_state, + num_tokens=128 +) + +# Récupérer depuis le cache lors des requêtes suivantes +entry, match_len = cache.fetch(images=["photo.jpg"], prompt="Describe this image") +if entry: + # Utiliser les embeddings mis en cache, contourner l'encodeur de vision + embeddings = entry.vision_embeddings + kv_state = entry.kv_cache +``` + +### Statistiques du cache + +```python +stats = cache.get_stats() +print(f"Hit rate: {stats.hit_rate:.1%}") +print(f"Memory used: {stats.memory_used_mb:.1f} MB") +print(f"Tokens saved: {stats.tokens_saved}") +``` + +### Gestion de la mémoire + +Le cache utilise une éviction LRU (Least Recently Used) lorsque la limite mémoire est atteinte. Chaque entrée suit : + +- La taille des embeddings de vision +- La taille du KV cache par couche +- La fréquence d'accès pour l'ordonnancement LRU + +En cas de pression mémoire, les entrées les moins récemment consultées sont évincées en premier. + +## Interface de chat Gradio + +Pour un chat multimodal interactif : + +```bash +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit +``` + +Prend en charge le glisser-déposer d'images et de vidéos. diff --git a/docs/fr/guides/python-api.md b/docs/fr/guides/python-api.md new file mode 100644 index 000000000..2ee713c89 --- /dev/null +++ b/docs/fr/guides/python-api.md @@ -0,0 +1,182 @@ +# Python API + +API Python directe pour un accès programmatique à vllm-mlx. 
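
Les classes présentées dans ce guide s'importent depuis `vllm_mlx.models` et `vllm_mlx.engine` :

```python
# Points d'entrée utilisés dans les sections ci-dessous
from vllm_mlx.models import MLXLanguageModel, MLXMultimodalLM
from vllm_mlx.engine import SimpleEngine, BatchedEngine
```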
+ +## Modèles de langage + +### Utilisation de base + +```python +from vllm_mlx.models import MLXLanguageModel + +# Load model +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) +``` + +### Génération en streaming + +```python +for chunk in model.stream_generate("Tell me a story about a robot"): + print(chunk.text, end="", flush=True) +``` + +### Interface de chat + +```python +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, who are you?"} +] +response = model.chat(messages) +print(response.text) +``` + +### Paramètres de génération + +```python +output = model.generate( + prompt="Write a poem", + max_tokens=256, + temperature=0.7, + top_p=0.9, + stop=["END", "\n\n"] +) +``` + +| Paramètre | Description | Défaut | +|-----------|-------------|--------| +| `max_tokens` | Nombre maximum de tokens à générer | 256 | +| `temperature` | Température d'échantillonnage (0-2) | 0.7 | +| `top_p` | Nucleus sampling | 0.9 | +| `stop` | Séquences d'arrêt | None | + +## Modèles vision-langage + +### Utilisation de base + +```python +from vllm_mlx.models import MLXMultimodalLM + +# Load model +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Describe an image +description = mllm.describe_image("photo.jpg") +print(description) +``` + +### Questions-réponses sur une image + +```python +answer = mllm.answer_about_image("photo.jpg", "What color is the car?") +print(answer) +``` + +### Plusieurs images + +```python +output = mllm.generate( + prompt="Compare these two images", + images=["image1.jpg", "image2.jpg"] +) +print(output.text) +``` + +### Compréhension vidéo + +```python +# From local file +output = mllm.generate( + prompt="What is happening in this video?", + videos=["video.mp4"], + video_fps=2.0, + video_max_frames=16 +) +print(output.text) + +# From URL +output = mllm.generate( + prompt="Describe this video", + videos=["https://example.com/video.mp4"], + video_fps=2.0 +) + +# Convenience method +description = mllm.describe_video("video.mp4", fps=2.0) +``` + +### Paramètres vidéo + +| Paramètre | Description | Défaut | +|-----------|-------------|--------| +| `video_fps` | Images par seconde à extraire | 2.0 | +| `video_max_frames` | Nombre maximum d'images à traiter | 32 | + +## API du moteur + +Pour les cas d'utilisation avancés, utilisez le moteur directement : + +### Moteur simple + +```python +from vllm_mlx.engine import SimpleEngine + +engine = SimpleEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) +print(output.text) + +await engine.stop() +``` + +### Moteur avec batching + +```python +from vllm_mlx.engine import BatchedEngine + +engine = BatchedEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +# Multiple concurrent requests +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) + +await engine.stop() +``` + +## Format de sortie + +Toutes les méthodes de génération retournent un objet `GenerationOutput` : + +```python +output = model.generate("Hello") + +print(output.text) # Generated text +print(output.prompt_tokens) # Input token count +print(output.completion_tokens) # Output token count +print(output.finish_reason) # "stop" or "length" +``` + +## Gestion des erreurs + +```python +from 
vllm_mlx.models import MLXLanguageModel + +try: + model = MLXLanguageModel("invalid-model") + model.load() +except Exception as e: + print(f"Failed to load model: {e}") +``` diff --git a/docs/fr/guides/reasoning.md b/docs/fr/guides/reasoning.md new file mode 100644 index 000000000..d37095093 --- /dev/null +++ b/docs/fr/guides/reasoning.md @@ -0,0 +1,267 @@ +# Reasoning Models + +vllm-mlx prend en charge les reasoning models qui affichent leur processus de thinking avant de fournir une réponse. Des modèles comme Qwen3 et DeepSeek-R1 encapsulent leur reasoning dans des balises `...`, et vllm-mlx peut analyser ces balises pour séparer le reasoning de la réponse finale. + +## Pourquoi utiliser le reasoning parsing ? + +Lorsqu'un reasoning model génère une sortie, elle ressemble généralement à ceci : + +``` + +Let me analyze this step by step. +First, I need to consider the constraints. +The answer should be a prime number less than 10. +Checking: 2, 3, 5, 7 are all prime and less than 10. + +The prime numbers less than 10 are: 2, 3, 5, 7. +``` + +Sans reasoning parsing, vous obtenez la sortie brute avec les balises incluses. Avec le reasoning parsing activé, le processus de thinking et la réponse finale sont séparés dans des champs distincts de la réponse de l'API. + +## Démarrage rapide + +### Démarrer le serveur avec un reasoning parser + +```bash +# For Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# For DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +### Format de réponse de l'API + +Lorsque le reasoning parsing est activé, la réponse de l'API inclut un champ `reasoning` : + +**Réponse sans streaming :** + +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": "The prime numbers less than 10 are: 2, 3, 5, 7.", + "reasoning": "Let me analyze this step by step.\nFirst, I need to consider the constraints.\nThe answer should be a prime number less than 10.\nChecking: 2, 3, 5, 7 are all prime and less than 10." + } + }] +} +``` + +**Réponse en streaming :** + +Les fragments sont envoyés séparément pour le reasoning et le contenu. Pendant la phase de reasoning, les fragments ont le champ `reasoning` renseigné. Lorsque le modèle passe à la réponse finale, les fragments ont le champ `content` renseigné : + +```json +{"delta": {"reasoning": "Let me analyze"}} +{"delta": {"reasoning": " this step by step."}} +{"delta": {"reasoning": "\nFirst, I need to"}} +... 
+{"delta": {"content": "The prime"}} +{"delta": {"content": " numbers less than 10"}} +{"delta": {"content": " are: 2, 3, 5, 7."}} +``` + +## Utilisation avec le SDK OpenAI + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Non-streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What are the prime numbers less than 10?"}] +) + +message = response.choices[0].message +print("Reasoning:", message.reasoning) # The thinking process +print("Answer:", message.content) # The final answer +``` + +### Streaming avec Reasoning + +```python +reasoning_text = "" +content_text = "" + +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Solve: 2 + 2 = ?"}], + stream=True +) + +for chunk in stream: + delta = chunk.choices[0].delta + if hasattr(delta, 'reasoning') and delta.reasoning: + reasoning_text += delta.reasoning + print(f"[Thinking] {delta.reasoning}", end="") + if delta.content: + content_text += delta.content + print(delta.content, end="") + +print(f"\n\nFinal reasoning: {reasoning_text}") +print(f"Final answer: {content_text}") +``` + +## Parsers disponibles + +### Parser Qwen3 (`qwen3`) + +Pour les modèles Qwen3 qui utilisent explicitement les balises `` et ``. + +- Nécessite **les deux** balises ouvrante et fermante +- Si les balises sont absentes, la sortie est traitée comme du contenu ordinaire +- Recommandé pour : Qwen3-0.6B, Qwen3-4B, Qwen3-8B et les modèles similaires + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +### Parser DeepSeek-R1 (`deepseek_r1`) + +Pour les modèles DeepSeek-R1 qui peuvent omettre la balise ouvrante ``. + +- Plus permissif que le parser Qwen3 +- Gère les cas où `` est implicite +- Le contenu avant `` est traité comme du reasoning même en l'absence de `` + +```bash +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +## Fonctionnement + +Le reasoning parser utilise une détection textuelle pour identifier les balises de thinking dans la sortie du modèle. Pendant le streaming, il suit la position courante dans la sortie afin d'acheminer chaque token vers `reasoning` ou `content`. + +``` +Model Output: Step 1: analyze...The answer is 42. + ├─────────────────────┤├─────────────────────┤ +Parsed: │ reasoning ││ content │ + └─────────────────────┘└─────────────────────┘ +``` + +L'analyse est sans état et s'appuie sur le texte accumulé pour déterminer le contexte, ce qui la rend robuste dans les scénarios de streaming où les tokens peuvent arriver en fragments arbitraires. + +## Conseils pour de meilleurs résultats + +### Rédaction des prompts + +Les reasoning models fonctionnent mieux lorsque vous encouragez une réflexion étape par étape : + +```python +messages = [ + {"role": "system", "content": "Think through problems step by step before answering."}, + {"role": "user", "content": "What is 17 × 23?"} +] +``` + +### Gestion de l'absence de reasoning + +Certains prompts peuvent ne pas déclencher de reasoning. 
Dans ce cas, `reasoning` vaut `None` et toute la sortie va dans `content` : + +```python +message = response.choices[0].message +if message.reasoning: + print(f"Model's thought process: {message.reasoning}") +print(f"Answer: {message.content}") +``` + +### Température et reasoning + +Les températures basses tendent à produire des schémas de reasoning plus cohérents : + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Explain quantum entanglement"}], + temperature=0.3 # More focused reasoning +) +``` + +## Compatibilité ascendante + +Lorsque `--reasoning-parser` n'est pas spécifié, le serveur se comporte comme avant : +- Les balises de thinking sont incluses dans le champ `content` +- Aucun champ `reasoning` n'est ajouté aux réponses + +Cela garantit que les applications existantes continuent de fonctionner sans modification. + +## Exemple : résolveur de problèmes mathématiques + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +def solve_math(problem: str) -> dict: + """Solve a math problem and return reasoning + answer.""" + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a math tutor. Show your work."}, + {"role": "user", "content": problem} + ], + temperature=0.2 + ) + + message = response.choices[0].message + return { + "problem": problem, + "work": message.reasoning, + "answer": message.content + } + +result = solve_math("If a train travels 120 km in 2 hours, what is its average speed?") +print(f"Problem: {result['problem']}") +print(f"\nWork shown:\n{result['work']}") +print(f"\nFinal answer: {result['answer']}") +``` + +## Exemples avec curl + +### Sans streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}] + }' +``` + +### Avec streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}], + "stream": true + }' +``` + +## Résolution de problèmes + +### Champ reasoning absent de la réponse + +- Vérifiez que le serveur a bien été démarré avec `--reasoning-parser` +- Vérifiez que le modèle utilise effectivement des balises de thinking (tous les prompts ne déclenchent pas le reasoning) + +### Le reasoning apparaît dans le contenu + +- Le modèle n'utilise peut-être pas le format de balises attendu +- Essayez un autre parser (`qwen3` ou `deepseek_r1`) + +### Reasoning tronqué + +- Augmentez `--max-tokens` si le modèle atteint la limite de tokens en plein milieu de sa réflexion + +## Voir aussi + +- [Modèles pris en charge](../reference/models.md) - Modèles qui prennent en charge le reasoning +- [Configuration du serveur](server.md) - Toutes les options du serveur +- [Référence CLI](../reference/cli.md) - Options de la ligne de commande diff --git a/docs/fr/guides/server.md b/docs/fr/guides/server.md new file mode 100644 index 000000000..6e5bc6528 --- /dev/null +++ b/docs/fr/guides/server.md @@ -0,0 +1,781 @@ +# Serveur compatible OpenAI + +vllm-mlx fournit un serveur FastAPI avec une compatibilité complète avec l'API OpenAI. + +Par défaut, le serveur n'écoute que sur `127.0.0.1`. Utilisez `--host 0.0.0.0` uniquement si vous souhaitez délibérément l'exposer au-delà de la machine locale. 
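
Par exemple, pour exposer volontairement le serveur au-delà de la machine locale tout en exigeant une clé API (`change-me` est une valeur d'exemple à remplacer ; les deux options sont décrites dans le tableau ci-dessous) :

```bash
# Exposition volontaire sur le réseau local : associez --host 0.0.0.0 à une clé API
vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \
  --host 0.0.0.0 \
  --port 8000 \
  --api-key change-me
```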
+ +## Démarrage du serveur + +### Mode simple (par défaut) + +Débit maximal pour un utilisateur unique : + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 +``` + +### Mode continuous batching + +Pour plusieurs utilisateurs simultanés : + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +### Avec paged cache + +Mise en cache efficace en mémoire pour la production : + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching --use-paged-cache +``` + +## Options du serveur + +| Option | Description | Défaut | +|--------|-------------|--------| +| `--port` | Port du serveur | 8000 | +| `--host` | Hôte du serveur | 127.0.0.1 | +| `--api-key` | Clé API pour l'authentification | None | +| `--rate-limit` | Requêtes par minute par client (0 = désactivé) | 0 | +| `--timeout` | Délai d'expiration des requêtes en secondes | 300 | +| `--enable-metrics` | Expose les métriques Prometheus sur `/metrics` | False | +| `--continuous-batching` | Active le batching pour plusieurs utilisateurs | False | +| `--use-paged-cache` | Active le paged KV cache | False | +| `--cache-memory-mb` | Limite mémoire du cache en Mo | Auto | +| `--cache-memory-percent` | Fraction de la RAM réservée au cache | 0.20 | +| `--max-tokens` | Nombre maximal de tokens par défaut | 32768 | +| `--max-request-tokens` | Valeur maximale de `max_tokens` acceptée des clients API | 32768 | +| `--default-temperature` | Température par défaut si non spécifiée | None | +| `--default-top-p` | Valeur top_p par défaut si non spécifiée | None | +| `--stream-interval` | Tokens par fragment de streaming | 1 | +| `--mcp-config` | Chemin vers le fichier de configuration MCP | None | +| `--reasoning-parser` | Parser pour les modèles reasoning (`qwen3`, `deepseek_r1`) | None | +| `--embedding-model` | Précharge un modèle d'embeddings au démarrage | None | +| `--enable-auto-tool-choice` | Active le tool calling automatique | False | +| `--tool-call-parser` | Parser de tool calling (voir [Tool Calling](tool-calling.md)) | None | + +## Points de terminaison de l'API + +### Chat Completions + +```bash +POST /v1/chat/completions +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Non-streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100 +) + +# Streaming +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Tell me a story"}], + stream=True +) +for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") +``` + +### Completions + +```bash +POST /v1/completions +``` + +```python +response = client.completions.create( + model="default", + prompt="The capital of France is", + max_tokens=50 +) +``` + +### Modèles + +```bash +GET /v1/models +``` + +Retourne les modèles disponibles. + +### Embeddings + +```bash +POST /v1/embeddings +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions +``` + +Voir le [Guide des embeddings](embeddings.md) pour plus de détails. + +### Vérification de l'état + +```bash +GET /health +``` + +Retourne l'état du serveur. 
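
Par exemple (le corps exact de la réponse n'est pas détaillé ici) :

```bash
curl http://localhost:8000/health
```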
+ +### Métriques + +```bash +GET /metrics +``` + +Point de terminaison de collecte Prometheus pour les métriques du serveur, du cache, du scheduler et des requêtes. +Le point de terminaison est désactivé par défaut et s'active avec `--enable-metrics`. + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-metrics +``` + +`/metrics` est intentionnellement non authentifié. Exposez-le uniquement sur un réseau de confiance ou derrière un reverse proxy ou un pare-feu qui limite les accès. + +### API Anthropic Messages + +```bash +POST /v1/messages +``` + +Point de terminaison compatible Anthropic qui permet à des outils comme Claude Code et OpenCode de se connecter directement à vllm-mlx. En interne, il traduit les requêtes Anthropic au format OpenAI, exécute l'inférence via le moteur, puis convertit la réponse au format Anthropic. + +Fonctionnalités : +- Réponses non-streaming et streaming (SSE) +- Messages système (chaîne simple ou liste de blocs de contenu) +- Conversations multi-tours avec messages utilisateur et assistant +- Tool calling avec blocs de contenu `tool_use` et `tool_result` +- Comptage de tokens pour le suivi du budget +- Contenu multimodal (images via blocs `source`) +- Détection de déconnexion client (retourne HTTP 499) +- Filtrage automatique des tokens spéciaux dans la sortie en streaming + +#### Non-streaming + +```python +from anthropic import Anthropic + +client = Anthropic(base_url="http://localhost:8000", api_key="not-needed") + +response = client.messages.create( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Hello!"}] +) +print(response.content[0].text) +# Response includes: response.id, response.model, response.stop_reason, +# response.usage.input_tokens, response.usage.output_tokens +``` + +#### Streaming + +Le streaming suit le protocole d'événements SSE d'Anthropic. Les événements sont émis dans cet ordre : +`message_start` -> `content_block_start` -> `content_block_delta` (répété) -> `content_block_stop` -> `message_delta` -> `message_stop` + +```python +with client.messages.stream( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Tell me a story"}] +) as stream: + for text in stream.text_stream: + print(text, end="") +``` + +#### Messages système + +Les messages système peuvent être une chaîne simple ou une liste de blocs de contenu : + +```python +# Plain string +response = client.messages.create( + model="default", + max_tokens=256, + system="You are a helpful coding assistant.", + messages=[{"role": "user", "content": "Write a hello world in Python"}] +) + +# List of content blocks +response = client.messages.create( + model="default", + max_tokens=256, + system=[ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Be concise in your answers."}, + ], + messages=[{"role": "user", "content": "What is 2+2?"}] +) +``` + +#### Tool calling + +Définissez les outils avec `name`, `description` et `input_schema`. Le modèle retourne des blocs de contenu `tool_use` lorsqu'il souhaite appeler un outil. Renvoyez les résultats sous forme de blocs `tool_result`. 
+ +```python +# Step 1: Send request with tools +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) + +# Step 2: Check if model wants to use tools +for block in response.content: + if block.type == "tool_use": + print(f"Tool: {block.name}, Input: {block.input}, ID: {block.id}") + # response.stop_reason will be "tool_use" + +# Step 3: Send tool result back +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[ + {"role": "user", "content": "What's the weather in Paris?"}, + {"role": "assistant", "content": response.content}, + {"role": "user", "content": [ + { + "type": "tool_result", + "tool_use_id": block.id, + "content": "Sunny, 22C" + } + ]} + ], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) +print(response.content[0].text) # "The weather in Paris is sunny, 22C." +``` + +Modes de sélection d'outil : + +| `tool_choice` | Comportement | +|---------------|--------------| +| `{"type": "auto"}` | Le modèle décide d'appeler ou non des outils (par défaut) | +| `{"type": "any"}` | Le modèle doit appeler au moins un outil | +| `{"type": "tool", "name": "get_weather"}` | Le modèle doit appeler l'outil spécifié | +| `{"type": "none"}` | Le modèle n'appellera aucun outil | + +#### Conversations multi-tours + +```python +messages = [ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What's my name?"}, +] + +response = client.messages.create( + model="default", + max_tokens=100, + messages=messages +) +``` + +#### Comptage de tokens + +```bash +POST /v1/messages/count_tokens +``` + +Compte les tokens d'entrée d'une requête Anthropic en utilisant le tokenizer du modèle. Utile pour le suivi du budget avant d'envoyer une requête. Comptabilise les tokens des messages système, des messages de conversation, des entrées `tool_use`, du contenu `tool_result` et des définitions d'outils (name, description, input_schema). 
+ +```python +import requests + +resp = requests.post("http://localhost:8000/v1/messages/count_tokens", json={ + "model": "default", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "system": "You are helpful.", + "tools": [{ + "name": "search", + "description": "Search the web", + "input_schema": {"type": "object", "properties": {"q": {"type": "string"}}} + }] +}) +print(resp.json()) # {"input_tokens": 42} +``` + +#### Exemples curl + +Non-streaming : + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +Streaming : + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "stream": true, + "messages": [{"role": "user", "content": "Tell me a joke"}] + }' +``` + +Comptage de tokens : + +```bash +curl http://localhost:8000/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}] + }' +# {"input_tokens": 12} +``` + +#### Champs de la requête + +| Champ | Type | Requis | Défaut | Description | +|-------|------|--------|--------|-------------| +| `model` | string | oui | - | Nom du modèle (utilisez `"default"` pour le modèle chargé) | +| `messages` | list | oui | - | Messages de conversation avec `role` et `content` | +| `max_tokens` | int | oui | - | Nombre maximal de tokens à générer | +| `system` | string or list | non | null | Invite système (chaîne ou liste de blocs `{"type": "text", "text": "..."}`) | +| `stream` | bool | non | false | Active le streaming SSE | +| `temperature` | float | non | 0.7 | Température d'échantillonnage (0.0 = déterministe, 1.0 = créatif) | +| `top_p` | float | non | 0.9 | Seuil d'échantillonnage nucleus | +| `top_k` | int | non | null | Échantillonnage top-k | +| `stop_sequences` | list | non | null | Séquences qui arrêtent la génération | +| `tools` | list | non | null | Définitions d'outils avec `name`, `description`, `input_schema` | +| `tool_choice` | dict | non | null | Mode de sélection d'outil (`auto`, `any`, `tool`, `none`) | +| `metadata` | dict | non | null | Métadonnées arbitraires (transmises telles quelles, non utilisées par le serveur) | + +#### Format de réponse + +Réponse non-streaming : + +```json +{ + "id": "msg_abc123...", + "type": "message", + "role": "assistant", + "model": "default", + "content": [ + {"type": "text", "text": "Hello! 
How can I help?"} + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 12, + "output_tokens": 8 + } +} +``` + +Lorsque des outils sont appelés, `content` inclut des blocs `tool_use` et `stop_reason` vaut `"tool_use"` : + +```json +{ + "content": [ + {"type": "text", "text": "Let me check the weather."}, + { + "type": "tool_use", + "id": "call_abc123", + "name": "get_weather", + "input": {"city": "Paris"} + } + ], + "stop_reason": "tool_use" +} +``` + +Raisons d'arrêt : + +| `stop_reason` | Signification | +|---------------|---------------| +| `end_turn` | Le modèle a terminé naturellement | +| `tool_use` | Le modèle souhaite appeler un outil | +| `max_tokens` | La limite `max_tokens` a été atteinte | + +#### Utilisation avec Claude Code + +Pointez Claude Code directement vers votre serveur vllm-mlx : + +```bash +# Start the server +vllm-mlx serve mlx-community/Qwen3-Coder-Next-235B-A22B-4bit \ + --continuous-batching \ + --enable-auto-tool-choice \ + --tool-call-parser hermes + +# In another terminal, configure Claude Code +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### État du serveur + +```bash +GET /v1/status +``` + +Point de terminaison de surveillance en temps réel qui retourne des statistiques globales du serveur et des détails par requête. Utile pour déboguer les performances, suivre l'efficacité du cache et surveiller la mémoire Metal GPU. + +```bash +curl -s http://localhost:8000/v1/status | python -m json.tool +``` + +Exemple de réponse : + +```json +{ + "status": "running", + "model": "mlx-community/Qwen3-8B-4bit", + "uptime_s": 342.5, + "steps_executed": 1247, + "num_running": 1, + "num_waiting": 0, + "total_requests_processed": 15, + "total_prompt_tokens": 28450, + "total_completion_tokens": 3200, + "metal": { + "active_memory_gb": 5.2, + "peak_memory_gb": 8.1, + "cache_memory_gb": 2.3 + }, + "cache": { + "type": "memory_aware_cache", + "entries": 5, + "hit_rate": 0.87, + "memory_mb": 2350 + }, + "requests": [ + { + "request_id": "req_abc123", + "phase": "generation", + "tokens_per_second": 45.2, + "ttft_s": 0.8, + "progress": 0.35, + "cache_hit_type": "prefix", + "cached_tokens": 1200, + "generated_tokens": 85, + "max_tokens": 256 + } + ] +} +``` + +Champs de la réponse : + +| Champ | Description | +|-------|-------------| +| `status` | État du serveur : `running`, `stopped` ou `not_loaded` | +| `model` | Nom du modèle chargé | +| `uptime_s` | Secondes écoulées depuis le démarrage du serveur | +| `steps_executed` | Nombre total d'étapes d'inférence exécutées | +| `num_running` | Nombre de requêtes en cours de génération de tokens | +| `num_waiting` | Nombre de requêtes en attente de prefill | +| `total_requests_processed` | Total des requêtes traitées depuis le démarrage | +| `total_prompt_tokens` | Total des tokens de prompt traités depuis le démarrage | +| `total_completion_tokens` | Total des tokens de complétion générés depuis le démarrage | +| `metal.active_memory_gb` | Mémoire Metal GPU actuellement utilisée (Go) | +| `metal.peak_memory_gb` | Pic d'utilisation de la mémoire Metal GPU (Go) | +| `metal.cache_memory_gb` | Utilisation de la mémoire cache Metal (Go) | +| `cache` | Statistiques du cache (type, entrées, taux de hit, utilisation mémoire) | +| `requests` | Liste des requêtes actives avec détails par requête | + +Champs par requête dans `requests` : + +| Champ | Description | +|-------|-------------| +| `request_id` | Identifiant unique de la requête | +| `phase` | 
Phase actuelle : `queued`, `prefill` ou `generation` | +| `tokens_per_second` | Débit de génération pour cette requête | +| `ttft_s` | TTFT (secondes) | +| `progress` | Pourcentage de complétion (0.0 à 1.0) | +| `cache_hit_type` | Type de correspondance dans le cache : `exact`, `prefix`, `supersequence`, `lcp` ou `miss` | +| `cached_tokens` | Nombre de tokens servis depuis le cache | +| `generated_tokens` | Tokens générés jusqu'à présent | +| `max_tokens` | Nombre maximal de tokens demandés | + +## Tool Calling + +Activez le tool calling compatible OpenAI avec `--enable-auto-tool-choice` : + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +Utilisez l'option `--tool-call-parser` pour sélectionner le parser adapté à votre modèle : + +| Parser | Modèles | +|--------|---------| +| `auto` | Détection automatique (essaie tous les parsers) | +| `mistral` | Mistral, Devstral | +| `qwen` | Qwen, Qwen3 | +| `llama` | Llama 3.x, 4.x | +| `hermes` | Hermes, NousResearch | +| `deepseek` | DeepSeek V3, R1 | +| `kimi` | Kimi K2, Moonshot | +| `granite` | IBM Granite 3.x, 4.x | +| `nemotron` | NVIDIA Nemotron | +| `xlam` | Salesforce xLAM | +| `functionary` | MeetKai Functionary | +| `glm47` | GLM-4.7, GLM-4.7-Flash | + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + } + }] +) + +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"{tc.function.name}: {tc.function.arguments}") +``` + +Voir le [Guide du tool calling](tool-calling.md) pour la documentation complète. + +## Modèles reasoning + +Pour les modèles qui exposent leur processus de réflexion (Qwen3, DeepSeek-R1), utilisez `--reasoning-parser` pour séparer le reasoning de la réponse finale : + +```bash +# Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +La réponse de l'API inclut un champ `reasoning` avec le processus de réflexion du modèle : + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) + +print(response.choices[0].message.reasoning) # Step-by-step thinking +print(response.choices[0].message.content) # Final answer +``` + +En streaming, les fragments de reasoning arrivent en premier, suivis des fragments de contenu : + +```python +for chunk in stream: + delta = chunk.choices[0].delta + if delta.reasoning: + print(f"[Thinking] {delta.reasoning}") + if delta.content: + print(delta.content, end="") +``` + +Voir le [Guide des modèles reasoning](reasoning.md) pour tous les détails. 
+ +## Sortie structurée (mode JSON) + +Forcez le modèle à retourner du JSON valide en utilisant `response_format` : + +### Mode JSON Object + +Retourne n'importe quel JSON valide : + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={"type": "json_object"} +) +# Output: {"colors": ["red", "blue", "green"]} +``` + +### Mode JSON Schema + +Retourne du JSON conforme à un schéma spécifique : + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "colors", + "schema": { + "type": "object", + "properties": { + "colors": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["colors"] + } + } + } +) +# Output validated against schema +data = json.loads(response.choices[0].message.content) +assert "colors" in data +``` + +### Exemple curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "List 3 colors"}], + "response_format": {"type": "json_object"} + }' +``` + +## Exemples curl + +### Chat + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 100 + }' +``` + +### Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": true + }' +``` + +## Configuration du streaming + +Contrôlez le comportement du streaming avec `--stream-interval` : + +| Valeur | Comportement | +|--------|--------------| +| `1` (par défaut) | Envoie chaque token immédiatement | +| `2-5` | Regroupe les tokens avant envoi | +| `10+` | Débit maximal, sortie plus fragmentée | + +```bash +# Smooth streaming +vllm-mlx serve model --continuous-batching --stream-interval 1 + +# Batched streaming (better for high-latency networks) +vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +## Intégration Open WebUI + +```bash +# 1. Start vllm-mlx server +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# 2. Start Open WebUI +docker run -d -p 3000:8080 \ + -e OPENAI_API_BASE_URL=http://host.docker.internal:8000/v1 \ + -e OPENAI_API_KEY=not-needed \ + --name open-webui \ + ghcr.io/open-webui/open-webui:main + +# 3. 
Open http://localhost:3000 +``` + +## Déploiement en production + +### Avec systemd + +Créez `/etc/systemd/system/vllm-mlx.service` : + +```ini +[Unit] +Description=vLLM-MLX Server +After=network.target + +[Service] +Type=simple +ExecStart=/usr/local/bin/vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching --use-paged-cache --port 8000 +Restart=always + +[Install] +WantedBy=multi-user.target +``` + +```bash +sudo systemctl enable vllm-mlx +sudo systemctl start vllm-mlx +``` + +### Paramètres recommandés + +Pour une production avec 50 utilisateurs simultanés ou plus : + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --timeout 120 \ + --port 8000 +``` diff --git a/docs/fr/guides/tool-calling.md b/docs/fr/guides/tool-calling.md new file mode 100644 index 000000000..8a7713220 --- /dev/null +++ b/docs/fr/guides/tool-calling.md @@ -0,0 +1,244 @@ +# Tool Calling + +vllm-mlx prend en charge le tool calling compatible OpenAI (function calling) avec un parsing automatique pour de nombreuses familles de modèles populaires. + +## Démarrage rapide + +Activez le tool calling en ajoutant le flag `--enable-auto-tool-choice` au démarrage du serveur : + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +Utilisez ensuite les outils avec l'API OpenAI standard : + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"] + } + } + }] +) + +# Check for tool calls +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"Function: {tc.function.name}") + print(f"Arguments: {tc.function.arguments}") +``` + +## Parsers disponibles + +Utilisez `--tool-call-parser` pour sélectionner un tool parser adapté à votre famille de modèles : + +| Parser | Alias | Modèles | Format | +|--------|-------|---------|--------| +| `auto` | | N'importe quel modèle | Détection automatique du format (essaie tous les parsers) | +| `mistral` | | Mistral, Devstral | Tableau JSON `[TOOL_CALLS]` | +| `qwen` | `qwen3` | Qwen, Qwen3 | XML `` ou `[Calling tool:]` | +| `llama` | `llama3`, `llama4` | Llama 3.x, 4.x | Balises `` | +| `hermes` | `nous` | Hermes, NousResearch | JSON `` dans XML | +| `deepseek` | `deepseek_v3`, `deepseek_r1` | DeepSeek V3, R1 | Délimiteurs Unicode | +| `kimi` | `kimi_k2`, `moonshot` | Kimi K2, Moonshot | Tokens `<\|tool_call_begin\|>` | +| `granite` | `granite3` | IBM Granite 3.x, 4.x | `<\|tool_call\|>` ou `` | +| `nemotron` | `nemotron3` | NVIDIA Nemotron | `` | +| `xlam` | | Salesforce xLAM | JSON avec tableau `tool_calls` | +| `functionary` | `meetkai` | MeetKai Functionary | Plusieurs blocs de fonctions | +| `glm47` | `glm4` | GLM-4.7, GLM-4.7-Flash | `` avec XML ``/`` | + +## Exemples par modèle + +### Mistral / Devstral + +```bash +# Devstral Small (optimized for coding and tool use) +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral + +# Mistral 
Instruct +vllm-mlx serve mlx-community/Mistral-7B-Instruct-v0.3-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +### Qwen + +```bash +# Qwen3 +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser qwen +``` + +### Llama + +```bash +# Llama 3.2 +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser llama +``` + +### DeepSeek + +```bash +# DeepSeek V3 +vllm-mlx serve mlx-community/DeepSeek-V3-0324-4bit \ + --enable-auto-tool-choice --tool-call-parser deepseek +``` + +### IBM Granite + +```bash +# Granite 4.0 +vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \ + --enable-auto-tool-choice --tool-call-parser granite +``` + +### NVIDIA Nemotron + +```bash +# Nemotron 3 Nano +vllm-mlx serve mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit \ + --enable-auto-tool-choice --tool-call-parser nemotron +``` + +### GLM-4.7 + +```bash +# GLM-4.7 Flash +vllm-mlx serve lmstudio-community/GLM-4.7-Flash-MLX-8bit \ + --enable-auto-tool-choice --tool-call-parser glm47 +``` + +### Kimi K2 + +```bash +# Kimi K2 +vllm-mlx serve mlx-community/Kimi-K2-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser kimi +``` + +### Salesforce xLAM + +```bash +# xLAM +vllm-mlx serve mlx-community/xLAM-2-fc-r-4bit \ + --enable-auto-tool-choice --tool-call-parser xlam +``` + +## Parser automatique + +Si vous n'êtes pas sûr du parser à utiliser, le parser `auto` tente de détecter le format automatiquement : + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser auto +``` + +Le parser automatique essaie les formats dans cet ordre : +1. Mistral (`[TOOL_CALLS]`) +2. Qwen bracket (`[Calling tool:]`) +3. Nemotron (``) +4. Qwen/Hermes XML (`{...}`) +5. Llama (`{...}`) +6. JSON brut + +## Streaming tool calls + +Le tool calling fonctionne avec le streaming. 
Les informations du tool call sont envoyées lorsque le modèle a terminé de générer : + +```python +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's 25 * 17?"}], + tools=[{ + "type": "function", + "function": { + "name": "calculator", + "description": "Calculate math expressions", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + } + }], + stream=True +) + +for chunk in stream: + if chunk.choices[0].delta.tool_calls: + for tc in chunk.choices[0].delta.tool_calls: + print(f"Tool call: {tc.function.name}({tc.function.arguments})") +``` + +## Gestion des résultats de tool calls + +Après avoir reçu un tool call, exécutez la fonction et renvoyez le résultat : + +```python +import json + +# First request - model decides to call a tool +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Tokyo?"}], + tools=[weather_tool] +) + +# Get the tool call +tool_call = response.choices[0].message.tool_calls[0] +tool_call_id = tool_call.id +function_name = tool_call.function.name +arguments = json.loads(tool_call.function.arguments) + +# Execute the function (your implementation) +result = get_weather(**arguments) # {"temperature": 22, "condition": "sunny"} + +# Send result back to model +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "user", "content": "What's the weather in Tokyo?"}, + {"role": "assistant", "tool_calls": [tool_call]}, + {"role": "tool", "tool_call_id": tool_call_id, "content": json.dumps(result)} + ], + tools=[weather_tool] +) + +print(response.choices[0].message.content) +# "The weather in Tokyo is sunny with a temperature of 22C." +``` + +## Gestion des balises think + +Les modèles qui produisent des balises de reasoning `...` (comme DeepSeek-R1, Qwen3, GLM-4.7) sont gérés automatiquement. Le parser supprime le contenu de la réflexion avant d'extraire les tool calls, de sorte que les balises de reasoning n'interfèrent jamais avec le parsing des tool calls. + +Cela fonctionne même lorsque `` a été injecté dans le prompt (balises think implicites avec uniquement un `` fermant). + +## Référence CLI + +| Option | Description | +|--------|-------------| +| `--enable-auto-tool-choice` | Active le tool calling automatique | +| `--tool-call-parser` | Sélectionne le parser (voir tableau ci-dessus) | + +Voir [Référence CLI](../reference/cli.md) pour toutes les options. diff --git a/docs/fr/guides/warm-prompts.md b/docs/fr/guides/warm-prompts.md new file mode 100644 index 000000000..5efcaf8fc --- /dev/null +++ b/docs/fr/guides/warm-prompts.md @@ -0,0 +1,187 @@ +# Warm Prompts + +Pré-remplissez le prefix cache au démarrage du serveur afin que la **première** requête +envoyée par un agent trouve un cache déjà chaud, sans payer le coût complet du prefill +pour son system prompt de plusieurs kilo-octets. + +## Quand utiliser cette fonctionnalité + +Les charges de travail agent. proxies vers des assistants de code ou de raisonnement, serveurs MCP, +orchestrateurs multi-agents. envoient toujours le même system prompt. Aujourd'hui, +la première requête d'un serveur froid paie le prefill complet pour ce system prompt. +Sur un modèle de plusieurs milliards de paramètres, cela représente plusieurs secondes de +TTFT, précisément au moment où un utilisateur attend la première réponse de son nouvel agent. 
+ +Si vous connaissez les system prompts de vos agents au moment du déploiement, écrivez-les +dans un fichier JSON et pointez `--warm-prompts` dessus. Le serveur exécute une complétion +de chat avec `max_tokens=1` pour chacun au démarrage, l'état KV cache est chargé dans le +prefix cache, et la première vraie requête correspond via strict-prefix. + +Nécessite `--continuous-batching` (le prefix cache y est hébergé). + +## Exemple rapide + +```bash +# Écrivez une seule fois les agents qui vous intéressent +cat > ~/.config/vllm-mlx/agents.json <<'JSON' +[ + [{"role": "system", "content": "You are a code assistant..."}] +] +JSON + +# Pointez le serveur dessus +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json +``` + +Au démarrage vous verrez : + +``` +[lifespan] Warm-up done (strict-prefix): 1 completed, 0 skipped, + 1431 prompt tokens in 0.2s +``` + +La première vraie requête partageant le system prompt réchauffé atteint le cache +avec `tokens_saved` proche de la longueur du prompt de warm-up. + +## Format de fichier + +Une liste JSON de premier niveau. Chaque entrée est elle-même une liste de messages de chat, +de même forme que `messages` dans `/v1/chat/completions`. + +```json +[ + [ + {"role": "system", "content": "You are a code assistant..."} + ], + [ + {"role": "system", "content": "You are a senior code reviewer..."} + ], + [ + {"role": "system", "content": "You are a planner..."}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello, what are we planning?"} + ] +] +``` + +Les system prompts à message unique sont le cas le plus courant. Les historiques multi-tours +sont pris en charge pour les scénarios où vous souhaitez réchauffer un début de conversation +spécifique (exemples few-shot, persona d'assistant récurrente). + +## Dimensionnement + +Les warm prompts sont traités **en parallèle** via `asyncio.gather`, donc N entrées +déclenchent N prefills simultanés au démarrage. Chaque prefill alloue du KV cache +pour la longueur de son prompt. + +**Recommandation : 1 à 3 entrées.** Cela couvre les chemins critiques des déploiements +agent typiques (une persona par entrée). Un fichier warm-prompts très grand sur un modèle +à mémoire limitée peut épuiser la marge disponible au démarrage. + +Si vous devez réchauffer des dizaines de personas, ouvrez une issue en décrivant votre +charge de travail et nous pourrons ajouter un paramètre `--warm-prompts-concurrency=N`. + +## Benchmarks + +**Configuration.** M4 Max, 128 Go de mémoire unifiée. Deux serveurs séparés par mesure +(froid et chaud), démarrage à froid isolé. Jeu de prompts `long` (~2 500 tokens utilisateur) +précédé d'un system prompt d'environ 1 700 tokens correspondant au prompt de warm-up. +`max_tokens=128`. bench-serve avec `--skip-preflight-token-count` afin que le preflight +`count_prompt_tokens` ne pollue pas le cache. 

| Modèle | conc | TTFT froid | TTFT chaud | Accélération |
|--------|-----:|-----------:|-----------:|-------------:|
| Qwen3-0.6B-8bit | 1 | 563 ms | 419 ms | 1.34x |
| Qwen3-0.6B-8bit | 4 | 1 723 ms | 1 282 ms | 1.34x |
| Qwen3-0.6B-8bit | 8 | 3 708 ms | 2 661 ms | 1.39x |
| Llama-3.2-3B-Instruct-4bit | 1 | 1 754 ms | 1 060 ms | 1.65x |
| Llama-3.2-3B-Instruct-4bit | 4 | 5 926 ms | 3 945 ms | 1.50x |
| Llama-3.2-3B-Instruct-4bit | 8 | 15 161 ms | 9 820 ms | 1.54x |
| Qwen3-4B-4bit | 1 | 4 937 ms | 2 191 ms | 2.25x |
| Qwen3-4B-4bit | 4 | 12 535 ms | 9 623 ms | 1.30x |
| Qwen3-4B-4bit | 8 | 38 148 ms | 23 878 ms | 1.60x |
| Qwen3.6-35B-A3B-4bit (MoE/hybrid) | 1 | 2 400 ms | 1 603 ms | 1.50x |
| Qwen3.6-35B-A3B-4bit | 4 | 8 735 ms | 6 054 ms | 1.44x |
| Qwen3.6-35B-A3B-4bit | 8 | 22 419 ms | 14 409 ms | 1.56x |

Les 12 configurations s'améliorent toutes. Les gains de TTFT sont les plus importants
quand le ratio prompt/total est le plus élevé (conc=1, long system prompt) et restent
significatifs sous charge concurrente.

**La génération tok/s** est neutre (dans ±5 %) pour les modèles denses.
Qwen3.6-35B-A3B (MoE) affiche une baisse de décodage de 20 à 35 % à conc >= 4, qui semble
liée à l'interaction du routage MoE avec la planification en batch. Les gains de TTFT
dominent néanmoins la latence de bout en bout sur les charges agent, mais tenez-en compte
si votre flux de travail est fortement limité par le décodage à forte concurrence.

## Fonctionnement interne

Le warm-up naïf (rendu du template de chat avec un message utilisateur fictif et mise en
cache des tokens) ne fonctionne pas pour les modèles hybrides SSM+attention
(Qwen3.5-MoE, Qwen3.6-MoE). Leurs couches de cache incluent un état SSM qui ne peut pas
être tronqué, si bien que `memory_cache.py` désactive la correspondance LCP. Le contenu
utilisateur fictif diverge du vrai contenu utilisateur, et une entrée mise en cache au
niveau des tokens n'est plus un strict prefix d'aucune vraie requête.

Le warmer ici rend le template de chat **deux fois** avec deux contenus utilisateur
distincts (`"__PROBE_A__"` et `"__PROBE_B__"`), trouve la position de caractère où les deux
chaînes divergent, puis tronque le premier rendu à cette frontière. Cette chaîne tronquée,
c'est-à-dire tout ce qui précède l'insertion du contenu utilisateur, est ce qui est envoyé au moteur.

Comme le chemin de vraie requête du moteur rend aussi le template avec `tokenize=False` puis
laisse le tokenizer encoder le résultat, les tokens du warm-up sont garantis d'être un strict
prefix de toute vraie requête avec un system prompt correspondant et un historique de chat
vide. Les correspondances strict-prefix fonctionnent sur tous les types de couche de cache,
y compris les chemins hybrides où LCP est désactivé.

## Administration

### Vider le prefix cache en mémoire

```bash
curl -X DELETE http://localhost:8000/v1/cache/prefix
```

Si le serveur a été démarré avec `--warm-prompts`, le warm-up se relance en arrière-plan
après la suppression. La réponse est retournée immédiatement sans attendre la fin du
re-warm.

Réponse :

```json
{"status": "cleared", "rewarm_scheduled": true}
```

### Inspecter l'état du cache

```bash
curl http://localhost:8000/v1/status | jq '.cache'
```

Après le démarrage avec warm-prompts, vous verrez `entry_count > 0` avant la première
requête utilisateur.
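
Pour une vérification rapide côté client, voici une esquisse en Python ; les champs présents dans l'objet `cache` dépendent du backend de cache utilisé (voir l'exemple de `/v1/status` plus haut), on affiche donc l'objet complet :

```python
import requests

# Esquisse : afficher l'état du prefix cache exposé par /v1/status.
# Les clés de l'objet "cache" varient selon le backend ; on n'en suppose aucune.
status = requests.get("http://localhost:8000/v1/status", timeout=5).json()
print(status.get("cache", {}))
```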
+ +## Mesurer l'impact sur votre configuration + +Pour mesurer l'impact sur votre modèle et vos prompts, utilisez `bench-serve` : + +```bash +# Froid : sans warm-prompts +vllm-mlx serve MODEL --continuous-batching & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag cold \ + --output cold.csv --format csv + +# Chaud : même configuration + --warm-prompts +vllm-mlx serve MODEL --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag warm \ + --output warm.csv --format csv +``` + +`--skip-preflight-token-count` est activé automatiquement quand +`--system-prompt-file` est fourni, afin que le preflight `count_prompt_tokens` +ne pollue pas le cache. Comparez `cold.csv` et `warm.csv` pour votre charge de travail. diff --git a/docs/fr/index.md b/docs/fr/index.md new file mode 100644 index 000000000..5149f8a47 --- /dev/null +++ b/docs/fr/index.md @@ -0,0 +1,71 @@ +# Documentation vLLM-MLX + +**Backend MLX pour Apple Silicon sur vLLM** - Accélération GPU pour le texte, les images, la vidéo et l'audio sur Mac + +## Qu'est-ce que vLLM-MLX ? + +vllm-mlx apporte l'accélération GPU native d'Apple Silicon à vLLM en intégrant : + +- **[MLX](https://github.com/ml-explore/mlx)** : le framework ML d'Apple avec mémoire unifiée et noyaux Metal +- **[mlx-lm](https://github.com/ml-explore/mlx-lm)** : inférence LLM optimisée avec KV cache et quantification +- **[mlx-vlm](https://github.com/Blaizzy/mlx-vlm)** : modèles vision-langage (VLM) pour l'inférence multimodale +- **[mlx-audio](https://github.com/Blaizzy/mlx-audio)** : Text-to-Speech et Speech-to-Text avec des voix natives +- **[mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings)** : embeddings textuels pour la recherche sémantique et le RAG + +## Fonctionnalités principales + +- **Multimodal** - Texte, image, vidéo et audio sur une seule plateforme +- **Accélération GPU native** sur Apple Silicon (M1, M2, M3, M4) +- **Voix TTS natives** - espagnol, français, chinois, japonais et 5 autres langues +- **Compatible API OpenAI** - remplacement direct du client OpenAI +- **Embeddings** - point de terminaison `/v1/embeddings` compatible OpenAI +- **MCP Tool Calling** - intégration d'outils externes via le Model Context Protocol +- **Paged KV Cache** - mise en cache efficace en mémoire avec partage de préfixe +- **Continuous Batching** - débit élevé pour plusieurs utilisateurs simultanés + +## Liens rapides + +### Démarrage + +- [Installation](getting-started/installation.md) +- [Démarrage rapide](getting-started/quickstart.md) + +### Guides utilisateur + +- [Serveur compatible OpenAI](guides/server.md) +- [API Python](guides/python-api.md) +- [Multimodal (images et vidéo)](guides/multimodal.md) +- [Audio (STT/TTS)](guides/audio.md) +- [Embeddings](guides/embeddings.md) +- [Modèles de reasoning](guides/reasoning.md) +- [Tool Calling](guides/tool-calling.md) +- [MCP et Tool Calling](guides/mcp-tools.md) +- [Continuous Batching](guides/continuous-batching.md) + +### Référence + +- [Commandes CLI](reference/cli.md) +- [Modèles pris en charge](reference/models.md) +- [Configuration](reference/configuration.md) + +### Benchmarks + +- [Benchmarks LLM](benchmarks/llm.md) +- [Benchmarks image](benchmarks/image.md) +- [Benchmarks vidéo](benchmarks/video.md) +- [Benchmarks audio](benchmarks/audio.md) + +### Développement + +- [Architecture](../development/architecture.md) +- 
[Contribuer](../development/contributing.md)
+
+## Prérequis
+
+- macOS sur Apple Silicon (M1/M2/M3/M4)
+- Python 3.10+
+- 8 Go de RAM recommandés
+
+## Licence
+
+Apache 2.0 - Consultez [LICENSE](../../LICENSE) pour les détails.
diff --git a/docs/fr/reference/cli.md b/docs/fr/reference/cli.md
new file mode 100644
index 000000000..8ee0184bf
--- /dev/null
+++ b/docs/fr/reference/cli.md
@@ -0,0 +1,210 @@
+# Référence CLI
+
+## Vue d'ensemble des commandes
+
+| Commande | Description |
+|---------|-------------|
+| `vllm-mlx serve` | Démarrer le serveur compatible OpenAI |
+| `vllm-mlx-bench` | Exécuter des benchmarks de performance |
+| `vllm-mlx-chat` | Démarrer l'interface de chat Gradio |
+
+## `vllm-mlx serve`
+
+Démarrer le serveur API compatible OpenAI.
+
+### Utilisation
+
+```bash
+vllm-mlx serve <model> [options]
+```
+
+### Options
+
+| Option | Description | Défaut |
+|--------|-------------|---------|
+| `--served-model-name` | Nom de modèle personnalisé exposé via l'API OpenAI. Si non défini, le chemin du modèle est utilisé comme nom. | None |
+| `--port` | Port du serveur | 8000 |
+| `--host` | Hôte du serveur | 127.0.0.1 |
+| `--api-key` | Clé API pour l'authentification | None |
+| `--rate-limit` | Requêtes par minute par client (0 = désactivé) | 0 |
+| `--timeout` | Délai d'attente des requêtes en secondes | 300 |
+| `--enable-metrics` | Exposer les métriques Prometheus sur `/metrics` | False |
+| `--continuous-batching` | Activer le continuous batching pour plusieurs utilisateurs | False |
+| `--cache-memory-mb` | Limite mémoire du cache en Mo | Auto |
+| `--cache-memory-percent` | Fraction de la RAM réservée au cache | 0.20 |
+| `--no-memory-aware-cache` | Utiliser le cache legacy basé sur le nombre d'entrées | False |
+| `--use-paged-cache` | Activer le KV cache paginé | False |
+| `--max-tokens` | Nombre maximum de tokens par défaut | 32768 |
+| `--max-request-tokens` | Valeur maximale de `max_tokens` acceptée depuis les clients API | 32768 |
+| `--stream-interval` | Tokens par fragment de streaming | 1 |
+| `--mcp-config` | Chemin vers le fichier de configuration MCP | None |
+| `--paged-cache-block-size` | Tokens par bloc de cache | 64 |
+| `--max-cache-blocks` | Nombre maximum de blocs de cache | 1000 |
+| `--max-num-seqs` | Nombre maximum de séquences simultanées | 256 |
+| `--default-temperature` | Température par défaut si non spécifiée dans la requête | None |
+| `--default-top-p` | Valeur top_p par défaut si non spécifiée dans la requête | None |
+| `--max-audio-upload-mb` | Taille maximale des fichiers audio téléversés pour `/v1/audio/transcriptions` | 25 |
+| `--max-tts-input-chars` | Longueur maximale du texte acceptée par `/v1/audio/speech` | 4096 |
+| `--reasoning-parser` | Analyseur pour les modèles de reasoning (`qwen3`, `deepseek_r1`) | None |
+| `--embedding-model` | Pré-charger un modèle d'embeddings au démarrage | None |
+| `--enable-auto-tool-choice` | Activer le tool calling automatique | False |
+| `--tool-call-parser` | Analyseur d'appels d'outils (`auto`, `mistral`, `qwen`, `llama`, `hermes`, `deepseek`, `kimi`, `granite`, `nemotron`, `xlam`, `functionary`, `glm47`) | None |
+
+### Exemples
+
+```bash
+# Simple mode (single user, max throughput)
+# Model path is used as the model name in the OpenAI API (e.g. model="mlx-community/Llama-3.2-3B-Instruct-4bit")
+vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit
+
+# Model will show up as 'mlx-community/Llama-3.2-3B-Instruct-4bit' in the `/v1/models` API endpoint.
+# View with `curl http://localhost:8000/v1/models` or similar.
+
+# With a custom API model name (model is accessed as "my-model" via the OpenAI API)
+# --served-model-name sets the name clients must use when calling the API (e.g. model="my-model")
+vllm-mlx serve --served-model-name my-model mlx-community/Llama-3.2-3B-Instruct-4bit
+# Note: Model will show up as 'my-model' in the `/v1/models` API endpoint.
+
+# Continuous batching (multiple users)
+vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --continuous-batching
+
+# With memory limit for large models
+vllm-mlx serve mlx-community/GLM-4.7-Flash-4bit \
+    --continuous-batching \
+    --cache-memory-mb 2048
+
+# Production with paged cache
+vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \
+    --continuous-batching \
+    --use-paged-cache \
+    --port 8000
+
+# With MCP tools
+vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json
+
+# Multimodal model
+vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit
+
+# Reasoning model (separates thinking from answer)
+vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3
+
+# DeepSeek reasoning model
+vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1
+
+# Tool calling with Mistral/Devstral
+vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \
+    --enable-auto-tool-choice --tool-call-parser mistral
+
+# Tool calling with Granite
+vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \
+    --enable-auto-tool-choice --tool-call-parser granite
+
+# With API key authentication
+vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --api-key your-secret-key
+
+# Expose Prometheus metrics
+vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --enable-metrics
+
+# Production setup with security options
+vllm-mlx serve mlx-community/Qwen3-4B-4bit \
+    --api-key your-secret-key \
+    --rate-limit 60 \
+    --timeout 120 \
+    --continuous-batching
+```
+
+### Sécurité
+
+Lorsque `--api-key` est défini, toutes les requêtes API requièrent l'en-tête `Authorization: Bearer <api-key>` :
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url="http://localhost:8000/v1",
+    api_key="your-secret-key"  # Must match --api-key
+)
+```
+
+Ou avec curl :
+
+```bash
+curl http://localhost:8000/v1/models \
+  -H "Authorization: Bearer your-secret-key"
+```
+
+## `vllm-mlx-bench`
+
+Exécuter des benchmarks de performance.
+
+### Utilisation
+
+```bash
+vllm-mlx-bench --model <model> [options]
+```
+
+### Options
+
+| Option | Description | Défaut |
+|--------|-------------|---------|
+| `--model` | Nom du modèle | Requis |
+| `--prompts` | Nombre de prompts | 5 |
+| `--max-tokens` | Nombre maximum de tokens par prompt | 256 |
+| `--quick` | Mode benchmark rapide | False |
+| `--video` | Exécuter le benchmark vidéo | False |
+| `--video-url` | URL vidéo personnalisée | None |
+| `--video-path` | Chemin vidéo personnalisé | None |
+
+### Exemples
+
+```bash
+# LLM benchmark
+vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit
+
+# Quick benchmark
+vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --quick
+
+# Image benchmark (auto-detected for VLM models)
+vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit
+
+# Video benchmark
+vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video
+
+# Custom video
+vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit \
+    --video --video-url https://example.com/video.mp4
+```
+
+## `vllm-mlx-chat`
+
+Démarrer l'interface de chat Gradio.
+ +### Utilisation + +```bash +vllm-mlx-chat --served-model-name [options] +``` + +### Options + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--model` | Nom du modèle | Requis | +| `--port` | Port Gradio | 7860 | +| `--text-only` | Désactiver le multimodal | False | + +### Exemples + +```bash +# Multimodal chat (text + images + video) +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Text-only chat +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit --text-only +``` + +## Variables d'environnement + +| Variable | Description | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | Modèle utilisé pour les tests | +| `HF_TOKEN` | Jeton HuggingFace | diff --git a/docs/fr/reference/configuration.md b/docs/fr/reference/configuration.md new file mode 100644 index 000000000..632cdf9e8 --- /dev/null +++ b/docs/fr/reference/configuration.md @@ -0,0 +1,189 @@ +# Référence de configuration + +## Configuration du serveur + +### Options de base + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--host` | Adresse hôte du serveur | `127.0.0.1` | +| `--port` | Port du serveur | `8000` | +| `--max-tokens` | Nombre maximum de tokens par défaut | `32768` | +| `--max-request-tokens` | Valeur maximale de `max_tokens` acceptée depuis les clients API | `32768` | +| `--default-temperature` | Température par défaut si non spécifiée dans la requête | None | +| `--default-top-p` | top_p par défaut si non spécifié dans la requête | None | + +### Options de sécurité + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--api-key` | Clé API pour l'authentification | None | +| `--rate-limit` | Requêtes par minute par client (0 = désactivé) | `0` | +| `--timeout` | Délai d'expiration des requêtes en secondes | `300` | +| `--enable-metrics` | Expose les métriques Prometheus sur `/metrics` | `false` | +| `--max-audio-upload-mb` | Taille maximale du fichier audio téléversé pour `/v1/audio/transcriptions` | `25` | +| `--max-tts-input-chars` | Longueur maximale du texte acceptée par `/v1/audio/speech` | `4096` | + +### Options de batching + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--continuous-batching` | Active le continuous batching | `false` | +| `--stream-interval` | Tokens par fragment de streaming | `1` | +| `--max-num-seqs` | Nombre maximum de séquences simultanées | `256` | + +### Options de cache + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--cache-memory-mb` | Limite mémoire du cache en Mo | Auto | +| `--cache-memory-percent` | Fraction de la RAM allouée au cache | `0.20` | +| `--no-memory-aware-cache` | Utilise le cache legacy basé sur le nombre d'entrées | `false` | +| `--use-paged-cache` | Active le KV cache paginé | `false` | +| `--paged-cache-block-size` | Tokens par bloc | `64` | +| `--max-cache-blocks` | Nombre maximum de blocs | `1000` | + +### Options d'appel d'outils + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--enable-auto-tool-choice` | Active l'appel automatique d'outils | `false` | +| `--tool-call-parser` | Parseur d'appels d'outils (voir [Appel d'outils](../guides/tool-calling.md)) | None | + +### Options de raisonnement + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--reasoning-parser` | Parseur pour les modèles de raisonnement (`qwen3`, `deepseek_r1`) | None | + +### Options d'embeddings + +| Option | Description | Défaut | 
+|--------|-------------|---------| +| `--embedding-model` | Précharge un modèle d'embedding au démarrage | None | + +### Options MCP + +| Option | Description | Défaut | +|--------|-------------|---------| +| `--mcp-config` | Chemin vers le fichier de configuration MCP | None | + +## Configuration MCP + +Créez le fichier `mcp.json` : + +```json +{ + "mcpServers": { + "server-name": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-name", "arg1"], + "env": { + "ENV_VAR": "value" + } + } + } +} +``` + +### Options du serveur MCP + +| Champ | Description | Obligatoire | +|-------|-------------|-------------| +| `command` | Commande exécutable | Oui | +| `args` | Arguments de la commande | Oui | +| `env` | Variables d'environnement | Non | + +## Options des requêtes API + +### Complétions de chat + +| Paramètre | Description | Défaut | +|-----------|-------------|---------| +| `model` | Nom du modèle | Obligatoire | +| `messages` | Messages du chat | Obligatoire | +| `max_tokens` | Nombre maximum de tokens à générer | 256 | +| `temperature` | Température d'échantillonnage | Défaut du modèle | +| `top_p` | Échantillonnage par noyau | Défaut du modèle | +| `stream` | Active le streaming | `true` | +| `stop` | Séquences d'arrêt | None | +| `tools` | Définitions des outils | None | +| `response_format` | Format de sortie (`json_object`, `json_schema`) | None | + +### Options multimodales + +| Paramètre | Description | Défaut | +|-----------|-------------|---------| +| `video_fps` | Images par seconde | 2.0 | +| `video_max_frames` | Nombre maximum d'images | 32 | + +## Variables d'environnement + +| Variable | Description | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | Modèle par défaut pour les tests | +| `HF_TOKEN` | Token d'authentification HuggingFace | +| `OPENAI_API_KEY` | À définir avec n'importe quelle valeur pour la compatibilité SDK | + +## Exemples de configurations + +### Développement (utilisateur unique) + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### Production (utilisateurs multiples) + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --port 8000 +``` + +### Avec appel d'outils + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral \ + --continuous-batching +``` + +### Avec les outils MCP + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --mcp-config mcp.json \ + --enable-auto-tool-choice \ + --tool-call-parser qwen \ + --continuous-batching +``` + +### Modèle de raisonnement + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit \ + --reasoning-parser qwen3 \ + --continuous-batching +``` + +### Avec embeddings + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --embedding-model mlx-community/multilingual-e5-small-mlx \ + --continuous-batching +``` + +### Débit élevé + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --stream-interval 5 \ + --max-num-seqs 256 +``` diff --git a/docs/fr/reference/models.md b/docs/fr/reference/models.md new file mode 100644 index 000000000..fc2a1ef93 --- /dev/null +++ b/docs/fr/reference/models.md @@ -0,0 +1,99 @@ +# Modèles pris en charge + +Tous les modèles quantifiés de [mlx-community sur HuggingFace](https://huggingface.co/mlx-community/models) sont compatibles. 
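+
+À titre d'illustration (esquisse indicative, hors du projet : elle suppose simplement que
+la bibliothèque `huggingface_hub` est installée), on peut aussi lister le catalogue par
+programme :
+
+```python
+# Esquisse : lister quelques dépôts mlx-community quantifiés en 4 bits.
+from huggingface_hub import list_models
+
+for model in list_models(author="mlx-community", search="4bit", limit=10):
+    print(model.id)
+```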
+ +Parcourez des milliers de modèles pré-optimisés à l'adresse : **https://huggingface.co/mlx-community/models** + +## Modèles de langage (via mlx-lm) + +| Famille de modèles | Tailles | Quantification | +|--------------------|---------|----------------| +| Llama 3.x, 4.x | 1B, 3B, 8B, 70B | 4-bit | +| Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit | +| Qwen2/Qwen3 | 0.5B à 72B | Variable | +| DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit | +| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit | +| GLM-4.7 | Flash, Base | 4-bit, 8-bit | +| Kimi K2 | Variable | 4-bit | +| Phi-3 | 3.8B, 14B | 4-bit | +| Granite 3.x, 4.x | Variable | 4-bit | +| Nemotron | 3 Nano 30B | 6-bit | + +### Modèles recommandés + +| Cas d'utilisation | Modèle | Mémoire | +|-------------------|--------|---------| +| Rapide / léger | `mlx-community/Qwen3-0.6B-8bit` | ~0,7 Go | +| Équilibré | `mlx-community/Llama-3.2-3B-Instruct-4bit` | ~1,8 Go | +| Qualité | `mlx-community/Llama-3.1-8B-Instruct-4bit` | ~4,5 Go | +| Grand modèle | `mlx-community/Qwen3-30B-A3B-4bit` | ~16 Go | + +## Modèles multimodaux (via mlx-vlm) + +| Famille de modèles | Exemples de modèles | +|--------------------|---------------------| +| **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` | +| **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` | +| **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` | +| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (vision + audio) | +| **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` | +| **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` | +| **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` | +| **Phi-3 Vision** | `Phi-3-vision-128k-instruct-4bit` | +| **DeepSeek-VL** | `deepseek-vl-7b-chat-4bit`, `deepseek-vl2-small-4bit` | + +### Modèles VLM recommandés + +| Cas d'utilisation | Modèle | Mémoire | +|-------------------|--------|---------| +| Rapide / léger | `mlx-community/Qwen3-VL-4B-Instruct-3bit` | ~3 Go | +| Équilibré | `mlx-community/Qwen3-VL-8B-Instruct-4bit` | ~6 Go | +| Qualité | `mlx-community/Qwen3-VL-30B-A3B-Instruct-6bit` | ~20 Go | + +## Modèles d'embeddings (via mlx-embeddings) + +| Famille de modèles | Exemples de modèles | +|--------------------|---------------------| +| **BERT** | `mlx-community/bert-base-uncased-mlx` | +| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` | +| **ModernBERT** | `mlx-community/ModernBERT-base-mlx` | + +## Modèles audio (via mlx-audio) + +| Type | Famille de modèles | Exemples de modèles | +|------|--------------------|---------------------| +| **STT** | Whisper | `mlx-community/whisper-large-v3-turbo` | +| **STT** | Parakeet | `mlx-community/parakeet-tdt-0.6b-v2` | +| **TTS** | Kokoro | `prince-canuma/Kokoro-82M` | +| **TTS** | Chatterbox | `chatterbox/chatterbox-tts-0.1` | + +## Détection automatique des modèles + +vllm-mlx détecte automatiquement les modèles multimodaux selon des motifs dans leur nom : +- Contient "VL", "Vision", "vision" +- Contient "llava", "idefics", "paligemma" +- Contient "pixtral", "molmo", "deepseek-vl" +- Contient "MedGemma", "Gemma-3", "Gemma-4" (variantes multimodales) + +## Utilisation des modèles + +### Depuis HuggingFace + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### Chemin local + +```bash +vllm-mlx serve /path/to/local/model +``` + +## Recherche de modèles + +Filtrez les modèles mlx-community par : +- **LLM** : `Llama`, `Qwen`, 
`Mistral`, `Phi`, `Gemma`, `DeepSeek`, `GLM`, `Kimi`, `Granite`, `Nemotron` +- **VLM** : `-VL-`, `llava`, `paligemma`, `pixtral`, `molmo`, `idefics`, `deepseek-vl`, `MedGemma` +- **Embedding** : `e5`, `bert`, `ModernBERT` +- **Taille** : `1B`, `3B`, `7B`, `8B`, `70B` +- **Quantification** : `4bit`, `8bit`, `bf16` diff --git a/docs/guides/audio.md b/docs/guides/audio.md index 2c26a8fe5..9f5b80769 100644 --- a/docs/guides/audio.md +++ b/docs/guides/audio.md @@ -200,6 +200,10 @@ Transcribe audio to text (OpenAI Whisper API compatible). - `language`: Language code (optional, auto-detected) - `response_format`: `json` or `text` +**Limits:** +- Default upload cap: 25 MiB +- Override with `--max-audio-upload-mb` + **Example:** ```bash curl http://localhost:8000/v1/audio/transcriptions \ @@ -218,6 +222,10 @@ Generate speech from text (OpenAI TTS API compatible). - `speed`: Speech speed (0.5 to 2.0) - `response_format`: `wav`, `mp3` +**Limits:** +- Default input cap: 4096 characters +- Override with `--max-tts-input-chars` + **Example:** ```bash curl http://localhost:8000/v1/audio/speech \ diff --git a/docs/guides/embeddings.md b/docs/guides/embeddings.md index 3fdd4426c..f8f1ce9eb 100644 --- a/docs/guides/embeddings.md +++ b/docs/guides/embeddings.md @@ -17,7 +17,7 @@ pip install mlx-embeddings>=0.0.5 vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit ``` -If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request. +If you don't use `--embedding-model`, the embedding model is loaded lazily on the first request, but only from the built-in request-time allowlist. ### Generate embeddings with the OpenAI SDK @@ -59,19 +59,25 @@ curl http://localhost:8000/v1/embeddings \ ## Supported Models -Any BERT, XLM-RoBERTa, or ModernBERT model from HuggingFace that is compatible with mlx-embeddings: +Supported request-time models: | Model | Use Case | Size | |-------|----------|------| | `mlx-community/all-MiniLM-L6-v2-4bit` | Fast, compact | Small | | `mlx-community/embeddinggemma-300m-6bit` | High quality | 300M | | `mlx-community/bge-large-en-v1.5-4bit` | Best for English | Large | +| `mlx-community/multilingual-e5-small-mlx` | Multilingual retrieval | Small | +| `mlx-community/multilingual-e5-large-mlx` | Multilingual retrieval | Large | +| `mlx-community/bert-base-uncased-mlx` | General BERT baseline | Base | +| `mlx-community/ModernBERT-base-mlx` | ModernBERT baseline | Base | + +Other embedding models require `--embedding-model` at server startup. ## Model Management ### Lazy loading -By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch models between requests and the previous model will be unloaded automatically. +By default, the embedding model is loaded on the first `/v1/embeddings` request. You can switch between the supported request-time models above, and the previous model will be unloaded automatically. ### Pre-loading at startup @@ -93,7 +99,7 @@ Create embeddings for the given input text(s). | Field | Type | Required | Description | |-------|------|----------|-------------| -| `model` | string | Yes | Model name from HuggingFace | +| `model` | string | Yes | Supported embedding model ID, or the startup-pinned model when `--embedding-model` is used | | `input` | string or list[string] | Yes | Text(s) to embed | **Response:** @@ -137,7 +143,7 @@ pip install mlx-embeddings>=0.0.5 ### Model not found -Make sure the model name matches a HuggingFace repository compatible with mlx-embeddings. 
You can pre-download models: +Make sure the model name matches one of the supported request-time IDs above, or start the server with `--embedding-model` to pin a custom model. You can pre-download supported models: ```bash huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit diff --git a/docs/guides/moe-top-k.md b/docs/guides/moe-top-k.md new file mode 100644 index 000000000..211f7c180 --- /dev/null +++ b/docs/guides/moe-top-k.md @@ -0,0 +1,122 @@ +# MoE top_k override (`--moe-top-k`) + +Reduces the number of experts activated per token in Mixture-of-Experts models +like Qwen3-30B-A3B, trading a small amount of quality for meaningfully higher +decode throughput. + +> **Status:** opt-in flag. Default behaviour is unchanged. Quality numbers +> below are for Qwen3-30B-A3B-4bit on M4 Max 128 GB — verify on your model +> before shipping this to production workloads. + +## What it does + +Qwen3-30B-A3B is trained with `top_k=8` — every token picks 8 out of 128 +experts. On Apple Silicon at batch=1 decode the expert matmul (`SwitchGLU`) +is the single biggest chunk of each layer's compute, and that cost scales +roughly linearly with `top_k`. Lowering `top_k` at inference time has been +shown (LExI 2025, Lynx 2024) to preserve most of the trained quality while +cutting decode time materially. + +`--moe-top-k N` iterates every layer of the loaded model, and on each layer +that has `.mlp.switch_mlp` (i.e. a sparse-MoE block) sets `top_k = N`. Dense +layers and dense models are untouched — the flag is a no-op for them. + +## Usage + +```bash +# Server +vllm-mlx serve mlx-community/Qwen3-30B-A3B-4bit \ + --continuous-batching \ + --moe-top-k 4 + +# Bench +vllm-mlx bench mlx-community/Qwen3-30B-A3B-4bit --moe-top-k 4 +``` + +The flag is rejected if `N` is greater than the model's trained `top_k` +(it only makes sense to lower, never to raise). + +## Measured impact + +### Decode throughput (M4 Max 128 GB, batch=1, greedy) + +| top_k | tok/s | vs baseline | +|---:|---:|---:| +| 8 (baseline) | 126.5 | — | +| 6 | 136.1 | +7.6% | +| 5 | 140.3 | +10.9% | +| 4 | 147.3 | +16.5% | + +### Quality (Qwen3-30B-A3B-4bit, lm-evaluation-harness, MLX backend) + + + +| top_k | MMLU (acc) | GSM8K (exact match) | Δ vs baseline | +|---:|---:|---:|---:| +| 8 | TBD | TBD | — | +| 6 | TBD | TBD | TBD | +| 5 | TBD | TBD | TBD | +| 4 | TBD | TBD | TBD | + +MMLU: 200 randomly-selected samples, 0-shot. +GSM8K: 100 randomly-selected samples, 0-shot, exact-match strict. + +These numbers are **directional** — full suites are larger and would shift +the absolute accuracy but not the relative delta between configs by much. + +### Greedy output parity + +With `top_k=4` on the 4-bit checkpoint we observed **identical first 16 +generated tokens** vs the baseline across every probe prompt we tried. This +suggests top_k=4 does not change the argmax in the early decode steps — the +model is internally robust to dropping half its activated experts. + +At `top_k=3` or lower quality would start to degrade visibly (not measured +here; inferred from LExI paper), so the flag is intentionally not lowered +below 1 at the config validation layer but the recommended floor for +production is `top_k=4`. + +## When to use it, when not to + +Use it when: +- You run a Qwen3 MoE (or compatible: Qwen3.5 MoE, Gemma-MoE) and single-user + decode throughput is your bottleneck. +- You have a workload where a small quality drop is acceptable in exchange + for a visible latency improvement. 
+- You're deploying on memory-bandwidth-bound hardware (M-series Apple Silicon) + where expert gather dominates per-step decode time. + +Skip it when: +- You serve dense models — flag is a no-op, adds nothing. +- You care about top-1% leaderboard accuracy on eval suites. +- You run long chain-of-thought / "thinking mode" generations where the + quality cliff may be steeper than 0-shot MMLU suggests. + +## Stacking with other optimizations + +This flag composes with quantization. On Qwen3-30B-A3B-4bit our measured +stack is: + +- 4-bit + top_k=8: 126.5 tok/s (baseline) +- 4-bit + top_k=4: 147.3 tok/s (+16.5%) +- 3-bit + top_k=8: 138.6 tok/s (+9.6%) +- 3-bit + top_k=6: 147.1 tok/s (+16.3%) — quality divergence measurable +- 3-bit + top_k=4: 157.3 tok/s (+24%) — **output quality breaks** (model answered a different question in our smoke test) + +3-bit + top_k=4 compounded the numerical error past the point where the +argmax is stable. Stick to at most one aggressive knob: either 4-bit + top_k=4 +or 3-bit + top_k=6. Both give approximately the same tok/s (~147) with very +different quality profiles. + +## Internals + +- Patch helper: `vllm_mlx.scheduler.apply_moe_top_k_override(model, k)` +- Applied in `Scheduler.__init__` after the model is loaded. +- Tests: `tests/test_moe_top_k.py` — covers dense models, mixed architectures, + and validation paths. + +## References + +- LExI: Layer-Adaptive Active Experts, [arXiv 2509.02753](https://arxiv.org/html/2509.02753) +- Not All Experts are Equal (NAEE), [ACL 2024](https://aclanthology.org/2024.acl-long.334.pdf) +- SwiftLM (`SWIFTLM_TOP_K` env knob prior art), [github.com/SharpAI/SwiftLM](https://github.com/SharpAI/SwiftLM) diff --git a/docs/guides/server.md b/docs/guides/server.md index 12badf87f..55993bd0d 100644 --- a/docs/guides/server.md +++ b/docs/guides/server.md @@ -2,6 +2,8 @@ vllm-mlx provides a FastAPI server with full OpenAI API compatibility. +By default the server binds only to `127.0.0.1`. Use `--host 0.0.0.0` only when you intentionally want to expose it beyond the local machine. + ## Starting the Server ### Simple Mode (Default) @@ -28,12 +30,23 @@ Memory-efficient caching for production: vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching --use-paged-cache ``` +### With Server-Wide Chat Template Defaults + +Set server defaults for chat template kwargs. Request-level `chat_template_kwargs` +values still win per key. 
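+
+As a client-side sketch of that per-key override (illustrative only; it assumes the OpenAI
+Python SDK's `extra_body` passthrough and a server launched as in the example below):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
+
+# The server default below disables thinking; this one request re-enables it.
+resp = client.chat.completions.create(
+    model="default",
+    messages=[{"role": "user", "content": "Explain prefix caching briefly."}],
+    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
+)
+print(resp.choices[0].message.content)
+```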
+ +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit \ + --reasoning-parser qwen3 \ + --default-chat-template-kwargs '{"enable_thinking": false}' +``` + ## Server Options | Option | Description | Default | |--------|-------------|---------| | `--port` | Server port | 8000 | -| `--host` | Server host | 0.0.0.0 | +| `--host` | Server host | 127.0.0.1 | | `--api-key` | API key for authentication | None | | `--rate-limit` | Requests per minute per client (0 = disabled) | 0 | | `--timeout` | Request timeout in seconds | 300 | @@ -43,8 +56,10 @@ vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous | `--cache-memory-mb` | Cache memory limit in MB | Auto | | `--cache-memory-percent` | Fraction of RAM for cache | 0.20 | | `--max-tokens` | Default max tokens | 32768 | +| `--max-request-tokens` | Maximum `max_tokens` accepted from API clients | 32768 | | `--default-temperature` | Default temperature when not specified | None | | `--default-top-p` | Default top_p when not specified | None | +| `--default-chat-template-kwargs` | Default chat template kwargs used when request `chat_template_kwargs` is omitted (JSON object) | None | | `--stream-interval` | Tokens per stream chunk | 1 | | `--mcp-config` | Path to MCP config file | None | | `--reasoning-parser` | Parser for reasoning models (`qwen3`, `deepseek_r1`) | None | diff --git a/docs/guides/warm-prompts.md b/docs/guides/warm-prompts.md new file mode 100644 index 000000000..ad46acd9f --- /dev/null +++ b/docs/guides/warm-prompts.md @@ -0,0 +1,196 @@ +# Warm Prompts + +Pre-populate the prefix cache at server startup so the **first** request +an agent sends hits a warm cache instead of paying the full prefill for +its multi-kilobyte system prompt. + +## When to use this + +Agent workloads — proxies to coding/reasoning assistants, MCP servers, +multi-agent orchestrators — always send the same system prompt. Today +the first request from a cold server pays the full prefill for that +system. On a multi-billion-parameter model that is several seconds of +TTFT, landing exactly when a user is waiting for their new agent to +respond for the first time. + +If you already know your agents' system prompts at deploy time, write +them to a JSON file and point `--warm-prompts` at it. The server +runs a `max_tokens=1` chat completion for each at startup, the KV +state lands in the prefix cache, and the first real request matches +via strict-prefix. + +Requires `--continuous-batching` (the prefix cache lives there). + +## Quick example + +```bash +# Write the agents you care about once +cat > ~/.config/vllm-mlx/agents.json <<'JSON' +[ + [{"role": "system", "content": "You are a code assistant..."}] +] +JSON + +# Point the server at it +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json +``` + +On start you'll see: + +``` +[lifespan] Warm-up done (strict-prefix): 1 completed, 0 skipped, + 1431 prompt tokens in 0.2s +``` + +The first real request that shares the warmed system prompt hits the +cache with `tokens_saved` close to the warm-up prompt length. + +## File format + +A top-level JSON list. Each entry is itself a list of chat messages — +same shape as `messages` in `/v1/chat/completions`. 
+ +```json +[ + [ + {"role": "system", "content": "You are a code assistant..."} + ], + [ + {"role": "system", "content": "You are a senior code reviewer..."} + ], + [ + {"role": "system", "content": "You are a planner..."}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello, what are we planning?"} + ] +] +``` + +Single-message system prompts are the common case. Multi-turn histories +are supported for scenarios where you want to warm a specific +conversation start (few-shot examples, a running assistant persona). + +## Sizing + +Warm-up prompts are processed **concurrently** via `asyncio.gather`, +so N entries fire N concurrent prefills at startup. Each prefill +allocates KV cache for its prompt length. + +**Recommended: 1–3 entries.** That covers the hot paths for typical +agent deployments (one persona per entry). A very large warm-prompts +file on a memory-tight model can exhaust headroom at boot. + +If you need to warm dozens of personas, open an issue with your +workload and we can add a `--warm-prompts-concurrency=N` cap. + +## Benchmarks + +**Setup.** M4 Max, 128 GB unified memory. Two separate servers per +measurement (cold vs warm), isolated cold start. `long` prompt set +(~2.5k user tokens) prepended with a ~1.7k-token system prompt to +match the warm-up prompt. `max_tokens=128`. bench-serve with +`--skip-preflight-token-count` so the count_prompt_tokens preflight +does not pollute the cache. + +| Model | conc | cold TTFT | warm TTFT | Speedup | +|-------|-----:|----------:|----------:|--------:| +| Qwen3-0.6B-8bit | 1 | 563 ms | 419 ms | 1.34x | +| Qwen3-0.6B-8bit | 4 | 1 723 ms | 1 282 ms | 1.34x | +| Qwen3-0.6B-8bit | 8 | 3 708 ms | 2 661 ms | 1.39x | +| Llama-3.2-3B-Instruct-4bit | 1 | 1 754 ms | 1 060 ms | 1.65x | +| Llama-3.2-3B-Instruct-4bit | 4 | 5 926 ms | 3 945 ms | 1.50x | +| Llama-3.2-3B-Instruct-4bit | 8 | 15 161 ms | 9 820 ms | 1.54x | +| Qwen3-4B-4bit | 1 | 4 937 ms | 2 191 ms | 2.25x | +| Qwen3-4B-4bit | 4 | 12 535 ms | 9 623 ms | 1.30x | +| Qwen3-4B-4bit | 8 | 38 148 ms | 23 878 ms | 1.60x | +| Qwen3.6-35B-A3B-4bit (MoE/hybrid) | 1 | 2 400 ms | 1 603 ms | 1.50x | +| Qwen3.6-35B-A3B-4bit | 4 | 8 735 ms | 6 054 ms | 1.44x | +| Qwen3.6-35B-A3B-4bit | 8 | 22 419 ms | 14 409 ms | 1.56x | + +All 12 configurations improve. TTFT savings are largest when the +prompt-to-total ratio is highest (conc=1, long system prompt) and +still meaningful under concurrent load. + +**Generation tok/s** is neutral (within ±5%) for the dense models. +Qwen3.6-35B-A3B (MoE) shows a 20–35% decode drop at conc ≥ 4 that +appears to be MoE routing interaction with batched scheduling. TTFT +savings still dominate end-to-end latency on agent workloads, but +note this if your workflow is heavily decode-bound at high +concurrency. + +## How it works + +The naive warm-up — render the chat template with a placeholder user +message and cache the tokens — does not work for hybrid SSM+attention +models (Qwen3.5-MoE, Qwen3.6-MoE). Their cache layers include SSM +state that cannot be trimmed, so `memory_cache.py` disables LCP +matching. The placeholder user content diverges from real user +content, and a tokens-level cached entry is no longer a strict prefix +of any real request. + +The warmer here renders the chat template **twice** with two distinct +user contents (`"__PROBE_A__"` and `"__PROBE_B__"`), finds the +character position where the two strings diverge, and truncates the +first rendering at that boundary. 
That truncated string — everything +up to the point where user content gets inserted — is what goes to +the engine. + +Because the engine's real-request path also renders the template with +`tokenize=False` and then lets the tokenizer encode the result, the +warm-up's tokens are guaranteed to be a strict prefix of any real +request with a matching system and empty chat history. Strict prefix +matches work on every cache layer type, including the hybrid paths +where LCP is disabled. + +## Admin + +### Clear the in-memory prefix cache + +```bash +curl -X DELETE http://localhost:8000/v1/cache/prefix +``` + +If the server was started with `--warm-prompts`, the warm-up re-runs +in the background after clear. The response returns immediately without +waiting for re-warm. + +Response: + +```json +{"status": "cleared", "rewarm_scheduled": true} +``` + +### Inspect cache state + +```bash +curl http://localhost:8000/v1/status | jq '.cache' +``` + +After startup with warm-prompts you will see `entry_count > 0` before +the first user request. + +## Benchmarking your own setup + +To measure the impact on your model and prompts, use `bench-serve`: + +```bash +# Cold: no warm-prompts +vllm-mlx serve MODEL --continuous-batching & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag cold \ + --output cold.csv --format csv + +# Warm: same server config + --warm-prompts +vllm-mlx serve MODEL --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag warm \ + --output warm.csv --format csv +``` + +`--skip-preflight-token-count` is auto-enabled when +`--system-prompt-file` is set, so the `count_prompt_tokens` preflight +does not pollute the cache. Compare `cold.csv` and `warm.csv` for +your workload. diff --git a/docs/reference/cli.md b/docs/reference/cli.md index 85c8d9c35..5921822c9 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -24,7 +24,7 @@ vllm-mlx serve [options] |--------|-------------|---------| | `--served-model-name` | Custom model name exposed through the OpenAI API. If not set, the model path is used as the name. 
| None | | `--port` | Server port | 8000 | -| `--host` | Server host | 0.0.0.0 | +| `--host` | Server host | 127.0.0.1 | | `--api-key` | API key for authentication | None | | `--rate-limit` | Requests per minute per client (0 = disabled) | 0 | | `--timeout` | Request timeout in seconds | 300 | @@ -35,6 +35,7 @@ vllm-mlx serve [options] | `--no-memory-aware-cache` | Use legacy entry-count cache | False | | `--use-paged-cache` | Enable paged KV cache | False | | `--max-tokens` | Default max tokens | 32768 | +| `--max-request-tokens` | Maximum `max_tokens` accepted from API clients | 32768 | | `--stream-interval` | Tokens per stream chunk | 1 | | `--mcp-config` | Path to MCP config file | None | | `--paged-cache-block-size` | Tokens per cache block | 64 | @@ -42,6 +43,9 @@ vllm-mlx serve [options] | `--max-num-seqs` | Max concurrent sequences | 256 | | `--default-temperature` | Default temperature when not specified in request | None | | `--default-top-p` | Default top_p when not specified in request | None | +| `--default-chat-template-kwargs` | Default chat template kwargs applied when request `chat_template_kwargs` is omitted (JSON object) | None | +| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | 25 | +| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | 4096 | | `--reasoning-parser` | Parser for reasoning models (`qwen3`, `deepseek_r1`) | None | | `--embedding-model` | Pre-load an embedding model at startup | None | | `--enable-auto-tool-choice` | Enable automatic tool calling | False | @@ -84,6 +88,11 @@ vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit # Reasoning model (separates thinking from answer) vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +# Disable server-wide thinking by default (request-level chat_template_kwargs still override) +vllm-mlx serve mlx-community/Qwen3-8B-4bit \ + --reasoning-parser qwen3 \ + --default-chat-template-kwargs '{"enable_thinking": false}' + # DeepSeek reasoning model vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index dcdff9d78..e8d18b050 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -6,11 +6,13 @@ | Option | Description | Default | |--------|-------------|---------| -| `--host` | Server host address | `0.0.0.0` | +| `--host` | Server host address | `127.0.0.1` | | `--port` | Server port | `8000` | | `--max-tokens` | Default max tokens | `32768` | +| `--max-request-tokens` | Maximum `max_tokens` accepted from API clients | `32768` | | `--default-temperature` | Default temperature when not specified in request | None | | `--default-top-p` | Default top_p when not specified in request | None | +| `--default-chat-template-kwargs` | Default chat template kwargs used when request `chat_template_kwargs` is omitted (JSON object) | None | ### Security Options @@ -20,6 +22,8 @@ | `--rate-limit` | Requests per minute per client (0 = disabled) | `0` | | `--timeout` | Request timeout in seconds | `300` | | `--enable-metrics` | Expose Prometheus metrics on `/metrics` | `false` | +| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | `25` | +| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | `4096` | ### Batching Options diff --git a/docs/reference/models.md b/docs/reference/models.md index d378de003..0ca34a105 100644 --- 
a/docs/reference/models.md +++ b/docs/reference/models.md @@ -55,7 +55,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | Model Family | Example Models | |--------------|----------------| | **BERT** | `mlx-community/bert-base-uncased-mlx` | -| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `multilingual-e5-large-mlx` | +| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` | | **ModernBERT** | `mlx-community/ModernBERT-base-mlx` | ## Audio Models (via mlx-audio) diff --git a/docs/zh/benchmarks/README.md b/docs/zh/benchmarks/README.md new file mode 100644 index 000000000..62c260cde --- /dev/null +++ b/docs/zh/benchmarks/README.md @@ -0,0 +1,63 @@ +# 基准测试 + +vllm-mlx 在 Apple Silicon 上的性能基准测试。 + +## 基准测试类型 + +- [LLM 基准测试](llm.md) - 文本生成性能 +- [图像基准测试](image.md) - 图像理解性能 +- [视频基准测试](video.md) - 视频理解性能 + +## 常用命令 + +```bash +# LLM benchmark +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# Image benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Video benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video +``` + +## 独立测试默认值 + +独立基准测试脚本内置了默认模型,可以直接运行: + +```bash +python tests/test_continuous_batching.py +python tests/test_prefix_cache.py +``` + +默认模型: +- `tests/test_continuous_batching.py` 对应 `mlx-community/Qwen3-8B-6bit` +- `tests/test_prefix_cache.py` 对应 `mlx-community/Qwen3-0.6B-8bit` + +如需测试其他模型,使用可选的 `--model` 参数: + +```bash +python tests/test_continuous_batching.py --model mlx-community/Qwen3-0.6B-8bit +python tests/test_prefix_cache.py --model mlx-community/Qwen3-8B-6bit +``` + +## 硬件配置 + +以下 Apple Silicon 配置已收录基准测试结果: + +| 芯片 | 内存 | Python | +|------|--------|--------| +| Apple M4 Max | 128 GB unified | 3.13 | +| Apple M1 Max | 64 GB unified | 3.12 | + +不同 Apple Silicon 芯片的测试结果会有所差异。 + +## 贡献基准测试 + +如果您使用的是其他 Apple Silicon 芯片,欢迎分享您的测试结果: + +```bash +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json +``` + +请在 [GitHub Issues](https://github.com/waybarrios/vllm-mlx/issues) 中提交您的结果。 diff --git a/docs/zh/benchmarks/audio.md b/docs/zh/benchmarks/audio.md new file mode 100644 index 000000000..c517a4700 --- /dev/null +++ b/docs/zh/benchmarks/audio.md @@ -0,0 +1,158 @@ +# 音频基准测试 + +## 语音转文本 (STT) 基准测试 + +### 运行 STT 基准测试 + +```bash +# Run with default test audio +python examples/benchmark_audio.py --stt + +# Run with your own audio file +python examples/benchmark_audio.py --stt --audio path/to/audio.wav +``` + +### 测试结果(M4 Max,128GB) + +**测试音频:** 46.7 秒的合成语音 + +| Model | Parameters | Load Time | Transcribe Time | RTF* | +|-------|------------|-----------|-----------------|------| +| whisper-tiny | 39M | 0.34s | 0.24s | **197x** | +| whisper-small | 244M | 0.18s | 0.47s | **98x** | +| whisper-medium | 769M | 0.35s | 1.15s | **41x** | +| whisper-large-v3 | 1.5B | 0.50s | 1.96s | **24x** | +| whisper-large-v3-turbo | 809M | 0.12s | 0.86s | **55x** | + +*RTF = 实时倍率(值越高速度越快)。RTF 为 100x 表示 1 分钟的音频约在 0.6 秒内转录完成。* + +### 测试结果(M1 Max,64GB) + +使用 Parakeet 进行 STT(默认环境,Whisper 因 numpy 依赖版本不匹配而不可用): + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| parakeet-tdt-0.6b-v2 | 0.28s | 1.01s | **9.9x** | +| parakeet-tdt-0.6b-v3 | 0.30s | 0.19s | **52.7x** | + +使用 Whisper 进行 STT(显式指定 `numpy==2.3.5` 并搭配 `uv run --no-sync`): + +| Model | Load Time | Transcribe Time | RTF | +|-------|-----------|-----------------|-----| +| whisper-tiny | 4.02s | 1.05s | **9.5x** | +| whisper-small | 
10.15s | 1.03s | **9.7x** | +| whisper-medium | 22.96s | 2.20s | **4.6x** | +| whisper-large-v3 | 38.34s | 0.96s | **10.5x** | +| whisper-large-v3-turbo | 21.79s | 0.70s | **14.3x** | +| parakeet-tdt-0.6b-v2 | 0.47s | 0.18s | **54.4x** | +| parakeet-tdt-0.6b-v3 | 1.13s | 0.18s | **54.6x** | + +### 模型推荐 + +| 使用场景 | 推荐模型 | 原因 | +|----------|----------|------| +| **实时转录** | whisper-tiny | 速度最快(197x RTF),延迟低 | +| **通用场景** | whisper-large-v3-turbo | 速度(55x)与质量兼顾,综合表现最佳 | +| **最高精度** | whisper-large-v3 | 准确率最高,支持 99 种以上语言 | +| **低内存** | whisper-small | 244M 参数,质量良好 | + +### 转录质量 + +所有模型均能正确转录测试音频。示例输出: + +``` +Input text: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." + +Whisper-large-v3 output: +"Welcome to this comprehensive speech to text demonstration. +This audio sample is designed to test the accuracy and speed of various speech recognition models. +The quick brown fox jumps over the lazy dog..." (identical) +``` + +### 支持的语言 + +Whisper 模型支持 99 种以上语言,包括: +- 英语、西班牙语、法语、德语、意大利语、葡萄牙语 +- 中文(普通话、粤语)、日语、韩语 +- 阿拉伯语、印地语、俄语、土耳其语、乌克兰语 +- 以及更多语言 + +## 文本转语音 (TTS) 基准测试 + +### 运行 TTS 基准测试 + +```bash +python examples/benchmark_audio.py --tts +``` + +### 测试结果(M4 Max,128GB) + +**测试内容:** 为 3 段文本样本(短、中、长)生成音频 + +| Model | Load Time | Chars/sec | RTF* | +|-------|-----------|-----------|------| +| Kokoro-82M-bf16 | 0.8s | 350+ | **22x** | +| Kokoro-82M-4bit | 0.4s | 320+ | **20x** | + +*RTF = 实时倍率。RTF 为 22x 表示 1 秒的音频约在 0.045 秒内生成完毕。* + +### TTS 测试结果(M1 Max,64GB) + +| Model | Load Time | Avg Chars/s | Avg RTF | +|-------|-----------|-------------|---------| +| Kokoro-82M-bf16 | 2.81s | 176.0 | **11.9x** | +| Kokoro-82M-4bit | 0.22s | 225.6 | **15.5x** | + +### TTS 质量 + +Kokoro 可生成自然流畅的语音,具备以下特性: +- 11 种内置音色(男声与女声) +- 支持 8 种语言(英语、西班牙语、法语、日语、中文、意大利语、葡萄牙语、印地语) +- 82M 参数,轻量且高效 + +## 音频处理基准测试 + +### SAM-Audio(音源分离) + +**测试内容:** 从 30 秒摇滚歌曲中分离鼓声 + +| Metric | Value | +|--------|-------| +| Model | sam-audio-large-fp16 | +| Processing time | ~20s | +| Peak memory | ~27 GB | +| Output sample rate | 48000 Hz | + +## 运行全部音频基准测试 + +```bash +# Run all benchmarks +python examples/benchmark_audio.py --all + +# Or run individually +python examples/benchmark_audio.py --stt +python examples/benchmark_audio.py --tts +``` + +## mlx-community 上的可用模型 + +### STT 模型 +- `mlx-community/whisper-tiny-mlx` +- `mlx-community/whisper-small-mlx` +- `mlx-community/whisper-medium-mlx` +- `mlx-community/whisper-large-v3-mlx` +- `mlx-community/whisper-large-v3-turbo` +- `mlx-community/parakeet-tdt-0.6b-v2` +- `mlx-community/parakeet-tdt-0.6b-v3` + +### TTS 模型 +- `mlx-community/Kokoro-82M-bf16`(推荐) +- `mlx-community/Kokoro-82M-4bit` +- `mlx-community/chatterbox-turbo-fp16` +- `mlx-community/VibeVoice-Realtime-0.5B-4bit` + +### 音频处理 +- `mlx-community/sam-audio-large-fp16` diff --git a/docs/zh/benchmarks/image.md b/docs/zh/benchmarks/image.md new file mode 100644 index 000000000..d978cfd29 --- /dev/null +++ b/docs/zh/benchmarks/image.md @@ -0,0 +1,138 @@ +# 图像基准测试 + +## 运行图像基准测试 + +```bash +# 完整基准测试(10 种分辨率) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# 快速基准测试(4 种分辨率) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --quick +``` + +## 测试结果 - Qwen3-VL-8B-Instruct-4bit(M4 Max,128GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.04s | 78 | 74.8 tok/s | +| 336x336 | 113K | 0.94s 
| 64 | 68.3 tok/s | +| 448x448 | 201K | 1.45s | 70 | 48.1 tok/s | +| 512x512 | 262K | 1.58s | 99 | 62.8 tok/s | +| 672x672 | 452K | 1.83s | 83 | 45.3 tok/s | +| 768x768 | 590K | 2.05s | 91 | 44.3 tok/s | +| 896x896 | 803K | 2.61s | 90 | 34.5 tok/s | +| 1024x1024 | 1.0M | 2.79s | 76 | 27.2 tok/s | +| 1280x720 | 922K | 2.97s | 96 | 32.4 tok/s | +| 1920x1080 | 2.1M | 6.30s | 89 | 14.1 tok/s | + +**摘要:** 所有分辨率的平均速度为 45.2 tok/s。最快为 224x224(74.8 tok/s),最慢为 1920x1080(14.1 tok/s)。 + +## 测试结果 - Qwen3-VL-8B-Instruct-4bit(M1 Max,64GB) + +本地 MLLM 基准测试: + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.84s | 78 | 42.5 tok/s | +| 448x448 | 201K | 2.28s | 70 | 30.7 tok/s | +| 768x768 | 590K | 4.39s | 91 | 20.7 tok/s | +| 1024x1024 | 1.0M | 6.41s | 76 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 14.92 | 315 | 21.1 | + +## 测试结果 - Qwen3-VL-4B-Instruct-3bit 服务端(M1 Max,64GB) + +| Resolution | Pixels | Time | Tokens | Speed | +|------------|--------|------|--------|-------| +| 224x224 | 50K | 1.65s | 113 | 68.4 tok/s | +| 448x448 | 201K | 2.09s | 120 | 57.5 tok/s | +| 768x768 | 590K | 2.93s | 106 | 36.2 tok/s | +| 1024x1024 | 1.0M | 4.12s | 100 | 24.3 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 4 | 10.79 | 439 | 40.7 | + +## MLLM 前缀缓存测试结果 + +``` +====================================================================== + MLLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-VL-4B-Instruct-3bit + Test: Verify KV cache reuse for repeated image/video + prompt combinations + Expected behavior: + - Same image + same prompt → cache HIT + - Same image + different prompt → cache MISS + - Different image + same prompt → cache MISS +---------------------------------------------------------------------- + SETUP: Loading Model +---------------------------------------------------------------------- + Model loaded in 0.11s + +---------------------------------------------------------------------- + SETUP: Creating Test Images +---------------------------------------------------------------------- + Resized: 224x224, 336x336, 512x512, 768x768 + +---------------------------------------------------------------------- + TEST 1: Image Cache - Basic Hit/Miss +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 1a | First image+prompt | MISS | MISS | 0.10ms | ✓ + 1b | Same image+prompt | HIT | HIT | 0.18ms | ✓ + 1c | Different prompt | MISS | MISS | 0.01ms | ✓ + 1d | Return to original | HIT | HIT | 0.18ms | ✓ + +---------------------------------------------------------------------- + TEST 2: Different Images +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 2a | Image A first request | MISS | MISS | 0.01ms | ✓ + 2b | Image B first request | MISS | MISS | 0.01ms | ✓ + 2c | Image A cached | HIT | HIT | 0.13ms | ✓ + +---------------------------------------------------------------------- + TEST 3: Image Resolutions 
+---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+-----------------------+----------+--------+--------+------- + 3.1a | 224x224 first | MISS | MISS | 0.01ms | ✓ + 3.1b | 224x224 cached | HIT | HIT | 0.20ms | ✓ + 3.2a | 336x336 first | MISS | MISS | 0.01ms | ✓ + 3.2b | 336x336 cached | HIT | HIT | 0.21ms | ✓ + 3.3a | 512x512 first | MISS | MISS | 0.12ms | ✓ + 3.3b | 512x512 cached | HIT | HIT | 0.20ms | ✓ + 3.4a | 768x768 first | MISS | MISS | 0.12ms | ✓ + 3.4b | 768x768 cached | HIT | HIT | 0.24ms | ✓ +====================================================================== +``` + +## 缓存键策略 + +- **图像:** `hash(image_content) + hash(prompt)` + +相同图像与相同提示词始终命中缓存。不同图像或不同提示词将不命中缓存。 + +## 性能提示 + +- 分辨率越小,处理速度越快(如 224x224 对比 1920x1080) +- 请根据任务需求选择合适的分辨率 +- 批量处理尺寸相近的图像,以获得稳定的吞吐量 + +## 指标说明 + +| Metric | Description | +|--------|-------------| +| Resolution | 图像尺寸(宽 x 高) | +| Pixels | 总像素数 | +| Time | 生成时间 | +| Tokens | 生成的输出 token 数量 | +| Speed | 每秒 token 数(tok/s) | diff --git a/docs/zh/benchmarks/llm.md b/docs/zh/benchmarks/llm.md new file mode 100644 index 000000000..416bb2355 --- /dev/null +++ b/docs/zh/benchmarks/llm.md @@ -0,0 +1,271 @@ +# LLM 基准测试 + +## 运行 LLM 基准测试 + +```bash +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 5 --max-tokens 256 +``` + +## 测试结果 (M4 Max, 128GB) + +| Model | Gen Speed | TTFT* | Memory | +|-------|-----------|-------|--------| +| Qwen3-0.6B-8bit | 402.3 tok/s | 58.6 ms | 0.68 GB | +| Llama-3.2-1B-Instruct-4bit | 463.6 tok/s | 49.2 ms | 0.69 GB | +| Qwen2.5-1.5B-Instruct-4bit | 308.5 tok/s | 86.2 ms | 0.84 GB | +| Llama-3.2-3B-Instruct-4bit | 200.1 tok/s | 81.4 ms | 1.79 GB | +| Qwen3-30B-A3B-4bit | 123.9 tok/s | 126.9 ms | 16.05 GB | +| NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit | 122.9 tok/s | 72.3 ms | 23.98 GB | + +*TTFT = 首个 token 的生成时间(模型开始输出前的延迟) + +## 测试结果 (M1 Max, 64GB) + +| Model | Runs | Prompt Tok | Gen Tok | Total Time (s) | TTFT Mean (ms) | TPOT Mean (ms) | Gen Speed (tok/s) | Total Throughput (tok/s) | +|-------|------|------------|---------|-----------------|-----------------|-----------------|-------------------|--------------------------| +| Qwen3-0.6B-8bit | 5 | 56 | 1280 | 5.66 | 119.0 | 3.97 | 251.9 | 236.1 | + +## Continuous Batching 测试结果 + +| Model | Single Request | Batch (5 req) | Speedup | +|-------|----------------|---------------|---------| +| Llama-3.2-1B-Instruct-4bit | 299.1 tok/s | 613.0 tok/s | **2.05x** | +| Llama-3.2-3B-Instruct-4bit | 137.6 tok/s | 208.1 tok/s | **1.51x** | +| Qwen3-0.6B-8bit | 328.1 tok/s | 1111.8 tok/s | **3.39x** | +| Qwen3-30B-A3B-4bit | 98.1 tok/s | 233.3 tok/s | **2.38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196.9 tok/s | 322.2 tok/s | **1.64x** | + +*批量处理 5 个并发请求可将 throughput 提升 1.5 到 3 倍。* + +### Continuous Batching (M1 Max, 64GB) + +| Requests | Total Tokens | Total Time (s) | Throughput (tok/s) | Requests/sec | +|----------|--------------|-----------------|--------------------|--------------| +| 5 | 315 | 0.64 | 492.5 | 7.82 | + +## Streaming 性能 + +| Model | TTFT | Generation Speed | +|-------|------|------------------| +| Llama-3.2-1B-Instruct-4bit | ~4.6ms | 218.9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10.7ms | 93.6 tok/s | +| Qwen3-0.6B-8bit | ~3.0ms | 328.5 tok/s | +| Qwen3-30B-A3B-4bit | ~10.2ms | 98.4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7.1ms | 140.3 tok/s | + +### Streaming 解码器 (M1 Max, 64GB) + +`vllm-mlx bench-detok`: + +| Tokens | Iterations | Naive Time | Streaming Time 
| Speedup | +|--------|------------|------------|----------------|---------| +| 742 | 5 | 1.69ms | 0.71ms | 2.39x | + +`examples/benchmark_detokenizer.py`: + +| Sequence | Tokens | decode() | Streaming | Speedup | +|----------|--------|----------|-----------|---------| +| Short | 8 | 0.029ms | 0.028ms | 1.04x | +| Medium | 103 | 0.206ms | 0.129ms | 1.59x | +| Long | 511 | 1.040ms | 0.502ms | 2.07x | +| 1K | 1191 | 2.446ms | 1.178ms | 2.08x | +| 2K | 2381 | 4.949ms | 2.356ms | 2.10x | +| 4K | 4761 | 9.887ms | 5.398ms | 1.83x | + +平均加速比:1.79x + +## Prefix Cache 测试结果 + +### Prefix Cache (M4 Max, 128GB) + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | ✓ + 1b | Same prompt | HIT | HIT | ✓ + 1c | Different prompt | MISS | MISS | ✓ + 1d | Return to prompt 1 | HIT | HIT | ✓ +====================================================================== +``` + +### Prefix Cache (M1 Max, 64GB) + +| Test | Expected | Actual | Time | Status | +|------|----------|--------|------|--------| +| First request | MISS | MISS | 203.5ms | PASS | +| Same prompt | HIT | HIT | 131.6ms | PASS | +| Different prompt | MISS or PREFIX_HIT | PREFIX_HIT (5 tok) | 135.3ms | PASS | + +最终缓存统计: + +| Cache Hits | Cache Misses | Hit Rate | Tokens Saved | Cached Speedup | +|------------|--------------|----------|--------------|----------------| +| 2 | 1 | 66.7% | 20 | 1.55x | + +## Paged Cache 测试结果 + +*测试:2 轮共 20 个真实推理请求,系统提示约 286 个 token* + +``` +====================================================================== + PAGED KV CACHE - REAL INFERENCE TEST +====================================================================== + +-------------------------------------------------- +Test 1: WITHOUT Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.47s + Throughput: 681.2 tok/s + Cache hits: 0 + Tokens saved: 0 + +-------------------------------------------------- +Test 2: WITH Paged Cache (2 rounds of 10) +-------------------------------------------------- + Time: 1.31s + Throughput: 765.8 tok/s + + Paged Cache Stats: + Blocks allocated: 25 + Shared blocks: 4 + Cache hits: 10 + Tokens saved: 2560 + +================================================== +SUMMARY +================================================== + Without paged cache: 681.2 tok/s + With paged cache: 765.8 tok/s + + Speedup: 1.12x + Cache hits: 10 (all Round 2 requests) + Tokens saved: 2,560 (~256 tokens × 10 requests) +================================================== +``` + +### Paged KV Cache (M1 Max, 64GB) + +推理基准测试(20 个请求): + +| Mode | Time (s) | Throughput (tok/s) | +|------|----------|--------------------| +| Without paged cache | 3.43 | 291.8 | +| With paged cache | 3.42 | 292.2 | + +| Speedup | Blocks Allocated | Shared Blocks | Cache Hits | Tokens Saved | +|---------|------------------|---------------|------------|--------------| +| 1.00x | 45 | 4 | 10 | 2560 | + +真实并发推理(20 个请求): + +| Mode | Time (s) | Throughput (tok/s) | +|------|----------|--------------------| +| Without paged cache | 4.32 | 231.7 | +| With paged cache | 4.35 | 229.7 | + +| Speedup | 
Blocks Allocated | Shared Blocks | Cache Hits | Tokens Saved | +|---------|------------------|---------------|------------|--------------| +| 0.99x | 49 | 8 | 10 | 5120 | + +内存节省示例: + +| Scenario | Memory Savings | +|----------|----------------| +| Shared system prompts | 70.8% | +| Concurrent memory efficiency | 83.5% | +| Prefix sharing branches | 38.5% | + +## Streaming 解码器分析 + +*第 9.1 阶段调查:mlx-lm 的 `BPEStreamingDetokenizer` 与朴素 `tokenizer.decode()` 对比* + +### 背景 + +朴素方法对每个 token 调用 `decode([token])`。理论上,streaming 解码器的时间复杂度为 O(T),而朴素解码为 O(T²)。 + +### 孤立基准测试结果 + +```bash +vllm-mlx bench-detok +``` + +复用同一解码器实例时(每次使用前调用 `reset()`): + +| Sequence | Tokens | Naive decode() | Streaming | Speedup | +|----------|--------|----------------|-----------|---------| +| Short | 8 | 0.020ms | 0.019ms | 1.05x | +| Medium | 103 | 0.155ms | 0.097ms | 1.59x | +| Long | 511 | 0.752ms | 0.371ms | **2.03x** | +| 1K tokens | 1191 | 1.743ms | 0.833ms | **2.09x** | +| 2K tokens | 2381 | 3.493ms | 1.737ms | **2.01x** | + +### 关键发现:实例创建开销 + +创建新的 `BPEStreamingDetokenizer` 实例**极其昂贵**: + +``` +100 tokenizer.detokenizer calls: 5.266s (52.7ms each!) +``` + +这意味着每个请求新建一个解码器实例会增加约 **52ms 的额外开销**,从而抵消所有性能收益。 + +### 实际影响 + +集成到调度器后(每个请求一个解码器实例): + +| Metric | Naive decode() | Streaming (new instance) | +|--------|----------------|--------------------------| +| Throughput (20 req) | 681 tok/s | 275 tok/s | +| Impact | - | **慢 60%** | + +### 结论 + +由于实例创建成本过高,streaming 解码器**目前不适合**在每个请求中独立使用。朴素的 `decode([token])` 方法在实践中仍然更快。 + +**未来优化方向**:在启动时预先创建一个解码器实例池,并在请求间复用这些实例。 + +## 指标参考 + +| Metric | Description | +|--------|-------------| +| **TTFT** | Time to First Token,模型开始响应前的延迟 (ms) | +| **TPOT** | Time Per Output Token,每个生成 token 之间的间隔 (ms/token) | +| **Generation TPS** | 每秒输出 token 数 (tok/s) | +| **Processing TPS** | 每秒处理输入或提示词 token 数 (tok/s) | +| **End-to-End Latency** | 从请求发出到收到完整响应的总时间 | +| **Total Throughput** | 每秒处理的总 token 数(输入加输出) | + +## 运行基准测试 + +```bash +# Basic benchmark +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit + +# With more prompts +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --prompts 10 + +# Save results +vllm-mlx-bench --model mlx-community/Qwen3-0.6B-8bit --output results.json + +# Continuous batching test +python tests/test_continuous_batching.py + +# Prefix cache test +python tests/test_prefix_cache.py + +# Paged cache test +python tests/test_paged_cache_real_inference.py + +# Streaming detokenizer benchmark +vllm-mlx bench-detok +vllm-mlx bench-detok mlx-community/Llama-3.2-1B-Instruct-4bit --iterations 5 +``` diff --git a/docs/zh/benchmarks/video.md b/docs/zh/benchmarks/video.md new file mode 100644 index 000000000..58bb75e3a --- /dev/null +++ b/docs/zh/benchmarks/video.md @@ -0,0 +1,128 @@ +# 视频基准测试 + +## 运行视频基准测试 + +```bash +# 完整基准测试(10 种配置,2-64 帧) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video + +# 快速基准测试(3 种帧数) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --quick + +# 自定义视频 +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video --video-url https://example.com/video.mp4 +``` + +## 结果 - Qwen3-VL-8B-Instruct-4bit(M4 Max,128GB) + +| Configuration | Frames | Time | Tokens | Speed | Memory | +|---------------|--------|------|--------|-------|--------| +| 2 frames @ 0.5fps | 2 | 4.48s | 256 | 57.1 tok/s | 6.4 GB | +| 4 frames @ 1fps | 4 | 4.65s | 256 | 55.0 tok/s | 6.4 GB | +| 6 frames @ 1fps | 6 | 5.15s | 197 | 38.2 tok/s | 6.6 GB | +| 8 frames @ 2fps | 8 | 6.45s | 240 | 37.2 tok/s | 6.8 GB | +| 12 frames @ 2fps | 
12 | 8.73s | 256 | 29.3 tok/s | 7.1 GB | +| 16 frames @ 2fps | 16 | 10.96s | 256 | 23.4 tok/s | 7.6 GB | +| 24 frames @ 4fps | 24 | 14.95s | 226 | 15.1 tok/s | 8.4 GB | +| 32 frames @ 4fps | 32 | 20.00s | 256 | 12.8 tok/s | 9.2 GB | +| 48 frames @ 8fps | 48 | 31.11s | 246 | 7.9 tok/s | 11.1 GB | +| 64 frames @ 8fps | 64 | 59.81s | 256 | 4.3 tok/s | 12.9 GB | + +**总结:** 2 帧时速度最快(57.1 tok/s),64 帧时速度最慢(4.3 tok/s)。内存占用从 6.4 GB 增长至 12.9 GB。 + +> **注意:** 96 帧及以上会因内存或计算资源限制,在大多数硬件上导致 GPU 超时。 + +## 结果 - Qwen3-VL-8B-Instruct-4bit(M1 Max,64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 8.84s | 256 | 29.0 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 13.05s | 256 | 19.6 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 21.60s | 256 | 11.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 43.48 | 768 | 17.7 | + +## 结果 - Qwen3-VL-4B-Instruct-3bit(M1 Max,64GB) + +| Configuration | Frames | FPS | Time | Tokens | Speed | +|---------------|--------|-----|------|--------|-------| +| 4 frames @ 1fps | 4 | 1.0 | 5.09s | 150 | 29.5 tok/s | +| 8 frames @ 2fps | 8 | 2.0 | 8.36s | 150 | 17.9 tok/s | +| 16 frames @ 2fps | 16 | 2.0 | 15.21s | 150 | 9.9 tok/s | + +| Configs | Total Time (s) | Total Tokens | Aggregate Tok/s | +|---------|-----------------|--------------|-----------------| +| 3 | 28.66 | 450 | 15.7 | + +## 视频缓存结果 + +``` +---------------------------------------------------------------------- + TEST 4: Video Cache - fps/max_frames in Cache Key +---------------------------------------------------------------------- + Config: fps=2.0, max_frames=16 + + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 4a | Video first request | MISS | MISS | 0.03ms | ✓ + 4b | Same video+params | HIT | HIT | 0.14ms | ✓ + 4c | Different fps (4.0) | MISS | MISS | 0.01ms | ✓ + 4d | Different max_frames (32) | MISS | MISS | 0.01ms | ✓ + 4.0.5a | fps=0.5 first | MISS | MISS | 0.01ms | ✓ + 4.0.5b | fps=0.5 cached | HIT | HIT | 0.14ms | ✓ + 4.1.0a | fps=1.0 first | MISS | MISS | 0.01ms | ✓ + 4.1.0b | fps=1.0 cached | HIT | HIT | 0.14ms | ✓ + 4.2.0a | fps=2.0 first | MISS | MISS | 0.01ms | ✓ + 4.2.0b | fps=2.0 cached | HIT | HIT | 0.14ms | ✓ + 4.4.0a | fps=4.0 first | MISS | MISS | 0.01ms | ✓ + 4.4.0b | fps=4.0 cached | HIT | HIT | 0.14ms | ✓ + +---------------------------------------------------------------------- + TEST 5: Additional Videos +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Time | Status + -------+---------------------------+----------+--------+--------+------- + 5a | Video 1 first | MISS | MISS | 0.01ms | ✓ + 5b | Video 2 first | MISS | MISS | 0.01ms | ✓ + 5c | Video 1 cached | HIT | HIT | 0.13ms | ✓ + 5d | Video 2 cached | HIT | HIT | 0.13ms | ✓ +``` + +## 缓存键策略 + +- **视频:** `hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +相同视频在 fps、max_frames 和 prompt 均相同时会命中缓存。任意参数发生变化则导致缓存未命中。 + +## 性能建议 + +- FPS 越低,处理速度越快。 +- 帧数越少,内存占用越小。 +- 64 帧是实际可用的最大值。 +- 96 帧及以上会导致 GPU 超时。 + +## 帧提取参考 + +| FPS | 10s 视频 | 30s 视频 | 60s 视频 | +|-----|-----------|-----------|-----------| +| 0.5 | 5 frames | 15 frames | 30 frames | +| 1.0 | 10 frames | 30 frames | 60 frames | +| 2.0 | 20 frames | 60 frames | 120 frames* | +| 4.0 | 40 frames | 120 frames* | 240 frames* | + 
+*可能触及 `max_frames` 上限 + +## 指标说明 + +| Metric | 说明 | +|--------|-------------| +| Configuration | FPS 与最大帧数设置 | +| Frames | 实际提取的帧数 | +| Time | 总生成时长 | +| Tokens | 生成的输出 token 数 | +| Speed | 每秒 token 数(tok/s) | +| Memory | GPU 内存占用 | diff --git a/docs/zh/getting-started/installation.md b/docs/zh/getting-started/installation.md new file mode 100644 index 000000000..84d7d0060 --- /dev/null +++ b/docs/zh/getting-started/installation.md @@ -0,0 +1,89 @@ +# 安装 + +## 系统要求 + +- 搭载 Apple Silicon(M1/M2/M3/M4)的 macOS +- Python 3.10+ + +## 使用 uv 安装(推荐) + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +uv pip install -e . +``` + +## 使用 pip 安装 + +```bash +git clone https://github.com/waybarrios/vllm-mlx.git +cd vllm-mlx + +pip install -e . +``` + +### 可选:视觉支持 + +使用 transformers 进行视频处理: + +```bash +pip install -e ".[vision]" +``` + +### 可选:音频支持(STT/TTS) + +```bash +pip install mlx-audio +``` + +### 可选:向量嵌入 + +```bash +pip install mlx-embeddings +``` + +## 安装内容说明 + +- `mlx`, `mlx-lm`, `mlx-vlm` - MLX 框架及模型库 +- `transformers`, `tokenizers` - HuggingFace 库 +- `opencv-python` - 视频处理 +- `gradio` - 对话界面 +- `psutil` - 资源监控 +- `mlx-audio`(可选)- 语音转文字与文字转语音 +- `mlx-embeddings`(可选)- 文本向量嵌入 + +## 验证安装 + +```bash +# Check CLI commands +vllm-mlx --help +vllm-mlx-bench --help +vllm-mlx-chat --help + +# Test with a small model +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --prompts 1 +``` + +## 故障排查 + +### 找不到 MLX + +请确认您使用的是 Apple Silicon 设备: +```bash +uname -m # Should output "arm64" +``` + +### 模型下载失败 + +请检查网络连接及 HuggingFace 访问权限。部分模型需要身份验证: +```bash +huggingface-cli login +``` + +### 内存不足 + +请使用更小的量化模型: +```bash +vllm-mlx serve mlx-community/Llama-3.2-1B-Instruct-4bit +``` diff --git a/docs/zh/getting-started/quickstart.md b/docs/zh/getting-started/quickstart.md new file mode 100644 index 000000000..a68d6df74 --- /dev/null +++ b/docs/zh/getting-started/quickstart.md @@ -0,0 +1,133 @@ +# 快速开始 + +## 选项 1:兼容 OpenAI 的服务器 + +启动服务器: + +```bash +# Simple mode - maximum throughput for single user +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# Continuous batching - for multiple concurrent users +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +使用 OpenAI Python SDK: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="mlx-community/Llama-3.2-3B-Instruct-4bit", + messages=[{"role": "user", "content": "Hello!"}], +) +print(response.choices[0].message.content) +``` + +或使用 curl: + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"model": "default", "messages": [{"role": "user", "content": "Hello!"}]}' +``` + +## 选项 2:直接使用 Python API + +```python +from vllm_mlx.models import MLXLanguageModel + +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) + +# Streaming +for chunk in model.stream_generate("Tell me a story"): + print(chunk.text, end="", flush=True) +``` + +## 选项 3:Gradio 聊天界面 + +```bash +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +在 http://localhost:7860 打开网页界面。 + +## 多模态模型 + +如需图像或视频理解,请使用 VLM 模型: + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +```python +response = client.chat.completions.create( + 
model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +``` + +## 推理模型 + +将模型的思考过程与最终答案分离: + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 +``` + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) +print(response.choices[0].message.content) # Final answer +``` + +## Embeddings + +为语义搜索和 RAG 生成文本 embeddings: + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit --embedding-model mlx-community/multilingual-e5-small-mlx +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +``` + +## 工具调用 + +为任意支持的模型启用函数调用: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +## 下一步 + +- [服务器指南](../guides/server.md) - 完整的服务器配置 +- [Python API](../guides/python-api.md) - 直接使用 API +- [多模态指南](../guides/multimodal.md) - 图像与视频 +- [音频指南](../guides/audio.md) - STT 与 TTS +- [Embeddings 指南](../guides/embeddings.md) - 文本 embeddings +- [推理模型](../guides/reasoning.md) - 思考模型 +- [工具调用](../guides/tool-calling.md) - 函数调用 +- [支持的模型](../reference/models.md) - 可用模型 diff --git a/docs/zh/guides/audio.md b/docs/zh/guides/audio.md new file mode 100644 index 000000000..5c1002ceb --- /dev/null +++ b/docs/zh/guides/audio.md @@ -0,0 +1,524 @@ +# 音频支持 + +vllm-mlx 通过 [mlx-audio](https://github.com/Blaizzy/mlx-audio) 支持音频处理,提供以下功能: + +- **STT(语音转文字)**:whisper、Parakeet +- **TTS(文字转语音)**:Kokoro、Chatterbox、VibeVoice、VoxCPM +- **音频处理**:SAM-Audio(人声分离) + +## 安装 + +```bash +# 核心音频支持 +pip install mlx-audio>=0.2.9 + +# TTS 所需依赖 +pip install sounddevice soundfile scipy numba tiktoken misaki spacy num2words loguru phonemizer + +# 下载 spacy 英文模型 +python -m spacy download en_core_web_sm + +# 非英语 TTS(西班牙语、法语等)需安装 espeak-ng: +# macOS +brew install espeak-ng + +# Ubuntu/Debian +# sudo apt-get install espeak-ng +``` + +或一次性安装所有音频依赖: + +```bash +pip install vllm-mlx[audio] +python -m spacy download en_core_web_sm +brew install espeak-ng # macOS,用于非英语语言 +``` + +## 快速开始 + +### STT(语音转录) + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# 转录音频文件 +with open("audio.mp3", "rb") as f: + transcript = client.audio.transcriptions.create( + model="whisper-large-v3", + file=f, + language="en" # optional + ) +print(transcript.text) +``` + +### TTS(语音合成) + +```python +# 生成语音 +audio = client.audio.speech.create( + model="kokoro", + input="Hello, how are you?", + voice="af_heart", + speed=1.0 +) + +# 保存到文件 +with open("output.wav", "wb") as f: + f.write(audio.content) +``` + +### 人声分离(SAM-Audio) + +从背景噪声、音乐或其他声音中提取人声: + +```python +from vllm_mlx.audio import AudioProcessor + +# 加载 SAM-Audio 模型 +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() + +# 从音频中分离语音 +result = processor.separate("meeting_with_music.mp3", description="speech") + +# 保存分离后的人声和背景音 +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background_only.wav") +``` + +**命令行示例:** +```bash +python examples/audio_separation_example.py meeting.mp3 --play +python examples/audio_separation_example.py song.mp3 --description music -o music.wav +``` + +### 鼓声分离演示 + +使用 SAM-Audio 从摇滚歌曲中分离鼓声: + +| 音频 | 说明 | 收听 | +|-------|-------------|--------| +| 
原始音频 | David Fesliyan 的"Get Ready"(30 秒,免版权) | [rock_get_ready.mp3](../../../examples/rock_get_ready.mp3) | +| 分离鼓声 | SAM-Audio 提取的鼓声 | [drums_isolated.wav](../../../examples/drums_isolated.wav) | +| 去除鼓声 | 移除鼓声后的音轨 | [rock_no_drums.wav](../../../examples/rock_no_drums.wav) | + +```bash +# 从摇滚歌曲中分离鼓声 +python examples/audio_separation_example.py examples/rock_get_ready.mp3 \ + --description "drums" \ + --output drums_isolated.wav \ + --background rock_no_drums.wav +``` + +**性能:** 在 M4 Max 上处理 30 秒音频约需 20 秒。 + +## 支持的模型 + +### STT 模型(语音转文字) + +| 模型 | 别名 | 语言数 | 速度 | 质量 | +|-------|-------|-----------|-------|---------| +| `mlx-community/whisper-large-v3-mlx` | `whisper-large-v3` | 99+ | 中等 | 最佳 | +| `mlx-community/whisper-large-v3-turbo` | `whisper-large-v3-turbo` | 99+ | 快速 | 优秀 | +| `mlx-community/whisper-medium-mlx` | `whisper-medium` | 99+ | 快速 | 良好 | +| `mlx-community/whisper-small-mlx` | `whisper-small` | 99+ | 极快 | 一般 | +| `mlx-community/parakeet-tdt-0.6b-v2` | `parakeet` | 仅英语 | 最快 | 优秀 | +| `mlx-community/parakeet-tdt-0.6b-v3` | `parakeet-v3` | 仅英语 | 最快 | 最佳 | + +**推荐方案:** +- 多语言场景:`whisper-large-v3` +- 仅英语场景:`parakeet`(速度快 3 倍) + +### TTS 模型(文字转语音) + +#### Kokoro(快速、轻量)- 推荐 + +| 模型 | 别名 | 参数量 | 支持语言 | +|-------|-------|------|-----------| +| `mlx-community/Kokoro-82M-bf16` | `kokoro` | 82M | EN、ES、FR、JA、ZH、HI、IT、PT | +| `mlx-community/Kokoro-82M-4bit` | `kokoro-4bit` | 82M | EN、ES、FR、JA、ZH、HI、IT、PT | + +**声音(11 种):** +- 美式女声:`af_heart`、`af_bella`、`af_nicole`、`af_sarah`、`af_sky` +- 美式男声:`am_adam`、`am_michael` +- 英式女声:`bf_emma`、`bf_isabella` +- 英式男声:`bm_george`、`bm_lewis` + +**语言代码:** +| 代码 | 语言 | 代码 | 语言 | +|------|----------|------|----------| +| `a` / `en` | 英语(美国) | `e` / `es` | Español | +| `b` / `en-gb` | 英语(英国) | `f` / `fr` | Français | +| `j` / `ja` | 日本語 | `z` / `zh` | 中文 | +| `i` / `it` | Italiano | `p` / `pt` | Português | +| `h` / `hi` | हिन्दी | | | + +#### Chatterbox(多语言、表现力强) + +| 模型 | 别名 | 参数量 | 支持语言 | +|-------|-------|------|-----------| +| `mlx-community/chatterbox-turbo-fp16` | `chatterbox` | 134M | 15+ 种语言 | +| `mlx-community/chatterbox-turbo-4bit` | `chatterbox-4bit` | 134M | 15+ 种语言 | + +**支持语言:** EN、ES、FR、DE、IT、PT、RU、JA、ZH、KO、AR、HI、NL、PL、TR + +#### VibeVoice(实时) + +| 模型 | 别名 | 参数量 | 适用场景 | +|-------|-------|------|----------| +| `mlx-community/VibeVoice-Realtime-0.5B-4bit` | `vibevoice` | 200M | 低延迟、仅英语 | + +#### VoxCPM(中英双语) + +| 模型 | 别名 | 参数量 | 支持语言 | +|-------|-------|------|-----------| +| `mlx-community/VoxCPM1.5` | `voxcpm` | 0.9B | ZH、EN | +| `mlx-community/VoxCPM1.5-4bit` | `voxcpm-4bit` | 200M | ZH、EN | + +### 音频处理模型 + +#### SAM-Audio(人声分离) + +| 模型 | 参数量 | 适用场景 | +|-------|------|----------| +| `mlx-community/sam-audio-large-fp16` | 3B | 最佳质量 | +| `mlx-community/sam-audio-large` | 3B | 标准 | +| `mlx-community/sam-audio-small-fp16` | 0.6B | 快速 | +| `mlx-community/sam-audio-small` | 0.6B | 轻量 | + +## API 参考 + +### POST /v1/audio/transcriptions + +将音频转录为文字(兼容 OpenAI Whisper API)。 + +**参数:** +- `file`:音频文件(mp3、wav、m4a、webm) +- `model`:模型名称或别名 +- `language`:语言代码(可选,自动检测) +- `response_format`:`json` 或 `text` + +**限制:** +- 默认上传上限:25 MiB +- 可通过 `--max-audio-upload-mb` 修改 + +**示例:** +```bash +curl http://localhost:8000/v1/audio/transcriptions \ + -F file=@audio.mp3 \ + -F model=whisper-large-v3 +``` + +### POST /v1/audio/speech + +从文字生成语音(兼容 OpenAI TTS API)。 + +**参数:** +- `model`:模型名称或别名 +- `input`:待合成文本 +- `voice`:声音 ID +- `speed`:语速(0.5 到 2.0) +- `response_format`:`wav`、`mp3` + +**限制:** +- 默认输入上限:4096 个字符 +- 可通过 `--max-tts-input-chars` 修改 + +**示例:** +```bash 
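+# 调用 kokoro 模型,用 af_heart 声音把 "Hello world" 合成为语音,并保存为 speech.wav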
+curl http://localhost:8000/v1/audio/speech \ + -d '{"model": "kokoro", "input": "Hello world", "voice": "af_heart"}' \ + -H "Content-Type: application/json" \ + --output speech.wav +``` + +### GET /v1/audio/voices + +列出模型可用的声音。 + +**示例:** +```bash +curl http://localhost:8000/v1/audio/voices?model=kokoro +``` + +## 命令行示例 + +### 实时转录与字幕 + +从麦克风进行实时 STT 转录: + +```bash +# 使用 whisper-large-v3 生成字幕(最佳质量) +python examples/closed_captions.py --language es --chunk 5 + +# 使用更快的模型降低延迟 +python examples/closed_captions.py --language en --model whisper-turbo --chunk 3 + +# 基础麦克风转录(先录音再转录) +python examples/mic_transcribe.py --language es + +# 实时分块转录 +python examples/mic_realtime.py --language es --chunk 3 + +# 带语音活动检测的实时转录 +python examples/mic_live.py --language es +``` + +**依赖安装:** +```bash +pip install sounddevice soundfile numpy +``` + +### 基础 TTS + +```bash +# 简单 TTS 示例 +python examples/tts_example.py "Hello, how are you?" --play + +# 使用不同声音 +python examples/tts_example.py "Hello!" --voice am_michael --play + +# 保存到文件 +python examples/tts_example.py "Welcome to the demo" -o greeting.wav + +# 列出可用声音 +python examples/tts_example.py --list-voices +``` + +### 多语言 TTS + +```bash +# 英语(自动选择最佳模型) +python examples/tts_multilingual.py "Hello world" --play + +# 西班牙语 +python examples/tts_multilingual.py "Hola mundo" --lang es --play + +# 法语 +python examples/tts_multilingual.py "Bonjour le monde" --lang fr --play + +# 日语 +python examples/tts_multilingual.py "こんにちは" --lang ja --play + +# 中文 +python examples/tts_multilingual.py "你好世界" --lang zh --play + +# 指定模型 +python examples/tts_multilingual.py "Hello" --model chatterbox --play + +# 列出所有模型 +python examples/tts_multilingual.py --list-models + +# 列出所有语言 +python examples/tts_multilingual.py --list-languages +``` + +### 商务助手语音示例 + +使用**原生声音**预生成的常见商务场景语音样本: + +| 语言 | 声音 | 内容 | 收听 | +|----------|-------|---------|--------| +| 英语 | af_heart | "Welcome to First National Bank. How may I assist you today?" | [assistant_bank_en.wav](../../../examples/assistant_bank_en.wav) | +| 西班牙语 | ef_dora | "Gracias por llamar a servicio al cliente. Un agente le atenderá pronto." | [assistant_service_es.wav](../../../examples/assistant_service_es.wav) | +| 法语 | ff_siwis | "Bienvenue. Votre appel est important pour nous." | [assistant_callcenter_fr.wav](../../../examples/assistant_callcenter_fr.wav) | +| 中文 | zf_xiaobei | "欢迎致电技术支持中心。我们将竭诚为您服务。" | [assistant_support_zh.wav](../../../examples/assistant_support_zh.wav) | + +**使用原生声音自行生成:** +```bash +# 英语 - 银行助手(原生声音:af_heart) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Welcome to First National Bank. How may I assist you today?" \ + --voice af_heart --lang_code a --file_prefix assistant_bank_en + +# 西班牙语 - 客服(原生声音:ef_dora) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Gracias por llamar a servicio al cliente. Un agente le atendera pronto." \ + --voice ef_dora --lang_code e --file_prefix assistant_service_es + +# 法语 - 呼叫中心(原生声音:ff_siwis) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "Bienvenue. Votre appel est important pour nous." 
\ + --voice ff_siwis --lang_code f --file_prefix assistant_callcenter_fr + +# 中文 - 技术支持(原生声音:zf_xiaobei) +python -m mlx_audio.tts.generate --model mlx-community/Kokoro-82M-bf16 \ + --text "欢迎致电技术支持中心。我们将竭诚为您服务。" \ + --voice zf_xiaobei --lang_code z --file_prefix assistant_support_zh +``` + +### 原生声音参考 + +| 语言 | 代码 | 声音 | +|----------|------|--------| +| 英语(美国) | `a` | af_heart、af_bella、af_nicole、am_adam、am_michael | +| 英语(英国) | `b` | bf_emma、bf_isabella、bm_george、bm_lewis | +| 西班牙语 | `e` | ef_dora、em_alex、em_santa | +| 法语 | `f` | ff_siwis | +| 中文 | `z` | zf_xiaobei、zf_xiaoni、zf_xiaoxiao、zm_yunjian、zm_yunxi | +| 日语 | `j` | jf_alpha、jf_gongitsune、jm_kumo | +| 意大利语 | `i` | if_sara、im_nicola | +| 葡萄牙语 | `p` | pf_dora、pm_alex | +| 印地语 | `h` | hf_alpha、hf_beta、hm_omega | + +## Python API + +### 直接调用(不启动服务器) + +```python +from vllm_mlx.audio import STTEngine, TTSEngine, AudioProcessor + +# Speech-to-Text +stt = STTEngine("mlx-community/whisper-large-v3-mlx") +stt.load() +result = stt.transcribe("audio.mp3") +print(result.text) + +# Text-to-Speech +tts = TTSEngine("mlx-community/Kokoro-82M-bf16") +tts.load() +audio = tts.generate("Hello world", voice="af_heart") +tts.save(audio, "output.wav") + +# Voice Separation +processor = AudioProcessor("mlx-community/sam-audio-large-fp16") +processor.load() +result = processor.separate("mixed_audio.mp3", description="speech") +processor.save(result.target, "voice_only.wav") +processor.save(result.residual, "background.wav") +``` + +### 便捷函数 + +```python +from vllm_mlx.audio import transcribe_audio, generate_speech, separate_voice + +# 快速转录 +result = transcribe_audio("audio.mp3") +print(result.text) + +# 快速 TTS +audio = generate_speech("Hello world", voice="af_heart") + +# 快速人声分离 +voice, background = separate_voice("mixed.mp3") +``` + +## 在对话中使用音频 + +在对话消息中附带音频(自动转录): + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Summarize this audio"}, + {"type": "audio_url", "audio_url": {"url": "file://meeting.mp3"}} + ] + }] +) +``` + +## 基准测试 + +测试环境:Apple M2 Max(32GB)。 + +### TTS 基准测试(Kokoro-82M-bf16) + +| 文本长度 | 音频时长 | 生成耗时 | RTF | 字符/秒 | +|-------------|----------------|----------|-----|-----------| +| 25 字符 | 1.95s | 0.43s | 4.6x | 58.5 | +| 88 字符 | 6.00s | 0.32s | 18.6x | 272.4 | +| 117 字符 | 7.92s | 0.27s | 29.0x | 427.4 | + +**汇总:** +- 模型加载时间:约 1.0 秒 +- 平均 RTF:**17.4x**(比实时快 17 倍) +- 平均字符/秒:**252.8** + +### STT 基准测试 + +| 模型 | 加载时间 | 转录耗时(6 秒音频) | RTF | +|-------|-----------|----------------------|-----| +| whisper-small | 0.25s | 0.20s | 30.2x | +| whisper-medium | 18.1s | 0.38s | 15.5x | +| whisper-large-v3 | ~30s | ~0.6s | ~10x | +| parakeet | ~0.5s | ~0.15s | ~40x | + +**说明:** +- RTF(实时倍率)表示处理速度相对于实时的倍数 +- 首次加载包含从 HuggingFace 下载模型的时间 +- 后续加载使用已缓存的模型 + +### 按场景推荐 + +| 场景 | 推荐模型 | 原因 | +|----------|------------------|-----| +| 英语 STT(快速) | `parakeet` | RTF 40x,内存占用低 | +| 多语言 STT | `whisper-large-v3` | 支持 99+ 种语言 | +| 低延迟 STT | `whisper-small` | RTF 30x,加载快 | +| 通用 TTS | `kokoro` | RTF 17x,质量良好 | +| 低内存 TTS | `kokoro-4bit` | 4-bit 量化 | + +## 性能建议 + +1. **英语场景使用 Parakeet**,比实时快 40 倍 +2. **使用 4-bit 模型**降低内存占用 +3. **使用 SAM-Audio small**加快人声分离速度 +4. **模型缓存**,引擎采用懒加载并自动缓存 +5. 
**提前下载模型**,避免首次运行时的延迟 + +## 常见问题 + +### mlx-audio 未安装 +``` +pip install mlx-audio>=0.2.9 +``` + +### 模型下载缓慢 +模型首次使用时从 HuggingFace 下载。可使用 `huggingface-cli download` 提前下载: +```bash +huggingface-cli download mlx-community/whisper-large-v3-mlx +huggingface-cli download mlx-community/Kokoro-82M-bf16 +``` + +### 内存不足 +请使用较小的模型或 4-bit 量化版本: +- 用 `whisper-small-mlx` 替代 `whisper-large-v3-mlx` +- 用 `Kokoro-82M-4bit` 替代 `Kokoro-82M-bf16` +- 用 `sam-audio-small` 替代 `sam-audio-large` + +### Kokoro 多语言问题(mlx-audio 0.2.9) + +使用非英语语言(西班牙语、中文、日语等)时若出现 `ValueError: too many values to unpack`,请应用以下修复: + +```python +# Fix for mlx_audio/tts/models/kokoro/pipeline.py line 443 +# Change: +# ps, _ = self.g2p(chunk) +# To: +g2p_result = self.g2p(chunk) +ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result +``` + +**一键修复:** +```bash +python -c " +import os +path = os.path.join(os.path.dirname(__import__('mlx_audio').__file__), 'tts/models/kokoro/pipeline.py') +with open(path, 'r') as f: content = f.read() +old = ' ps, _ = self.g2p(chunk)' +new = ''' # Fix: handle both tuple (en) and string (zh/ja/es) returns from g2p + g2p_result = self.g2p(chunk) + ps = g2p_result[0] if isinstance(g2p_result, tuple) else g2p_result''' +if old in content: + with open(path, 'w') as f: f.write(content.replace(old, new)) + print('Fix applied!') +" +``` + +此问题的原因是英语 g2p 返回元组 `(phonemes, tokens)`,而其他语言仅返回字符串。 diff --git a/docs/zh/guides/continuous-batching.md b/docs/zh/guides/continuous-batching.md new file mode 100644 index 000000000..e4c24775c --- /dev/null +++ b/docs/zh/guides/continuous-batching.md @@ -0,0 +1,174 @@ +# Continuous Batching + +Continuous batching 在同时服务多个用户时能显著提升 throughput。 + +## 启用 Continuous Batching + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching +``` + +## 与 Paged Cache 配合使用 + +启用高效内存的前缀共享: + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit --continuous-batching --use-paged-cache +``` + +## 工作原理 + +### 简单模式(默认) +- 每次处理一个请求 +- 单用户场景下 throughput 最高 +- 无 batching 额外开销 + +### Continuous Batching 模式 +- 多个请求同时处理 +- 并发用户场景下 throughput 更高 +- 每个请求存在少量额外开销 + +### Paged Cache +- KV cache 以固定大小的块存储 +- 相同的系统提示词共享同一批块 +- 10 个以上并发用户时节省 80% 以上内存 + +## 性能测试结果 + +**Continuous Batching 测试结果(M4 Max,128GB):** + +| 模型 | 单请求 | Batch(5 个请求) | 加速比 | +|-------|----------------|---------------|---------| +| Llama-3.2-1B-Instruct-4bit | 299.1 tok/s | 613.0 tok/s | **2.05x** | +| Llama-3.2-3B-Instruct-4bit | 137.6 tok/s | 208.1 tok/s | **1.51x** | +| Qwen3-0.6B-8bit | 328.1 tok/s | 1111.8 tok/s | **3.39x** | +| Qwen3-30B-A3B-4bit | 98.1 tok/s | 233.3 tok/s | **2.38x** | +| Qwen2.5-1.5B-Instruct-4bit | 196.9 tok/s | 322.2 tok/s | **1.64x** | + +*5 个并发请求的 batching 可将 throughput 提升 1.5 到 3 倍。* + +## Streaming 性能 + +**Streaming 性能(M4 Max,128GB):** + +| 模型 | TTFT | 生成速度 | +|-------|------|------------------| +| Llama-3.2-1B-Instruct-4bit | ~4.6ms | 218.9 tok/s | +| Llama-3.2-3B-Instruct-4bit | ~10.7ms | 93.6 tok/s | +| Qwen3-0.6B-8bit | ~3.0ms | 328.5 tok/s | +| Qwen3-30B-A3B-4bit | ~10.2ms | 98.4 tok/s | +| Qwen2.5-1.5B-Instruct-4bit | ~7.1ms | 140.3 tok/s | + +*TTFT = Time to First Token(首 token 延迟)* + +## Streaming 配置 + +使用 `--stream-interval` 控制 token 发送频率: + +```bash +# 每个 token 立即发送(最流畅) +vllm-mlx serve model --continuous-batching --stream-interval 1 + +# 批量发送 token(适合高延迟场景) +vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +| 值 | 行为 | +|-------|----------| +| `1` | 每个 token 立即发送 | +| `2-5` | 攒批后再发送 | +| `10+` | 最大化 throughput,输出颗粒度更大 | + +## 内存管理 + +对于大型模型,prefix cache 
可能占用大量内存。内存感知缓存会自动进行管理: + +```bash +# 自动检测(使用可用内存的 20%) +vllm-mlx serve model --continuous-batching + +# 显式限制 +vllm-mlx serve model --continuous-batching --cache-memory-mb 2048 + +# 自定义百分比 +vllm-mlx serve model --continuous-batching --cache-memory-percent 0.10 +``` + +| 选项 | 说明 | +|--------|-------------| +| `--cache-memory-mb` | 以 MB 为单位设置显式上限 | +| `--cache-memory-percent` | 可用内存的占比(默认值:0.20) | +| `--no-memory-aware-cache` | 使用基于条目数量的旧式缓存 | + +## Prefix Cache + +Prefix caching 对重复提示词复用 KV cache。 + +### 工作原理 + +``` +User 1: System prompt (500 tokens) → Creates 8 blocks +User 2: Same system prompt → Shares 8 blocks (ref_count++) +User N: Same system prompt → Shares 8 blocks (ref_count++) + +Memory savings: 80%+ for 10+ concurrent users +``` + +### 缓存键策略 + +- **LLM**:`hash(prompt)` +- **图片**:`hash(image_content) + hash(prompt)` +- **视频**:`hash(video_path) + hash(fps) + hash(max_frames) + hash(prompt)` + +### 测试 Prefix Cache + +```bash +python tests/test_prefix_cache.py +``` + +``` +====================================================================== + LLM PREFIX CACHE TEST +====================================================================== + Model: mlx-community/Qwen3-0.6B-8bit + Expected behavior: + - Same prompt → cache HIT + - Different prompt → cache MISS or PREFIX_HIT (shared template tokens) +---------------------------------------------------------------------- + Results: + Step | Description | Expected | Actual | Status + -------+---------------------+----------+--------+------- + 1a | First request | MISS | MISS | PASS + 1b | Same prompt | HIT | HIT | PASS + 1c | Different prompt | MISS | MISS | PASS + 1d | Return to prompt 1 | HIT | HIT | PASS +====================================================================== +``` + +## 运行基准测试 + +```bash +# Continuous batching 基准测试 +python tests/test_continuous_batching.py + +# Prefix cache 测试 +python tests/test_prefix_cache.py +``` + +## 适用场景 + +| 场景 | 模式 | +|----------|------| +| 单用户,追求最高速度 | 简单模式(默认) | +| 多用户并发 | `--continuous-batching` | +| 大型模型(7B+) | `--continuous-batching --cache-memory-mb 2048` | +| 生产环境,提示词共享 | `--continuous-batching --use-paged-cache` | + +## 生产环境配置 + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --port 8000 +``` diff --git a/docs/zh/guides/embeddings.md b/docs/zh/guides/embeddings.md new file mode 100644 index 000000000..9eb7ba7bb --- /dev/null +++ b/docs/zh/guides/embeddings.md @@ -0,0 +1,150 @@ +# Embeddings + +vllm-mlx 通过 [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings) 支持文本 embeddings,提供与 OpenAI 兼容的 `/v1/embeddings` 接口。 + +## 安装 + +```bash +pip install mlx-embeddings>=0.0.5 +``` + +## 快速入门 + +### 启动带有 embedding 模型的服务器 + +```bash +# 启动时预加载指定的 embedding 模型 +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +如果不使用 `--embedding-model`,embedding 模型会在第一次请求时按需加载,但仅限于内置的请求时许可列表中的模型。 + +### 使用 OpenAI SDK 生成 embeddings + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# 单条文本 +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions + +# 批量文本 +response = client.embeddings.create( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input=[ + "I love machine learning", + "Deep learning is fascinating", + "Natural language processing rocks" + ] +) +for item in response.data: + print(f"Text {item.index}: {len(item.embedding)} 
dimensions") +``` + +### 使用 curl + +```bash +curl http://localhost:8000/v1/embeddings \ + -H "Content-Type: application/json" \ + -d '{ + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "input": ["Hello world", "How are you?"] + }' +``` + +## 支持的模型 + +请求时支持的模型: + +| 模型 | 适用场景 | 规模 | +|-------|----------|------| +| `mlx-community/all-MiniLM-L6-v2-4bit` | 快速、轻量 | 小 | +| `mlx-community/embeddinggemma-300m-6bit` | 高质量 | 300M | +| `mlx-community/bge-large-en-v1.5-4bit` | 英文效果最佳 | 大 | +| `mlx-community/multilingual-e5-small-mlx` | 多语言检索 | 小 | +| `mlx-community/multilingual-e5-large-mlx` | 多语言检索 | 大 | +| `mlx-community/bert-base-uncased-mlx` | 通用 BERT 基准 | 基础 | +| `mlx-community/ModernBERT-base-mlx` | ModernBERT 基准 | 基础 | + +其他 embedding 模型需要在启动服务器时通过 `--embedding-model` 指定。 + +## 模型管理 + +### 按需加载 + +默认情况下,embedding 模型在第一次收到 `/v1/embeddings` 请求时加载。你可以在上述请求时支持的模型之间切换,切换后旧模型会自动卸载。 + +### 启动时预加载 + +使用 `--embedding-model` 可在启动时加载模型。设置该参数后,只有该指定模型可用于 embeddings: + +```bash +vllm-mlx serve my-llm-model --embedding-model mlx-community/all-MiniLM-L6-v2-4bit +``` + +请求其他模型将返回 400 错误。 + +## API 参考 + +### POST /v1/embeddings + +为给定的输入文本生成 embeddings。 + +**请求体:** + +| 字段 | 类型 | 是否必填 | 描述 | +|-------|------|----------|-------------| +| `model` | string | 是 | 支持的 embedding 模型 ID,或使用 `--embedding-model` 时启动时固定的模型 | +| `input` | string 或 list[string] | 是 | 待嵌入的文本 | + +**响应:** + +```json +{ + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": [0.023, -0.982, ...]}, + {"object": "embedding", "index": 1, "embedding": [0.112, -0.543, ...]} + ], + "model": "mlx-community/all-MiniLM-L6-v2-4bit", + "usage": {"prompt_tokens": 12, "total_tokens": 12} +} +``` + +## Python API + +### 不启动服务器直接使用 + +```python +from vllm_mlx.embedding import EmbeddingEngine + +engine = EmbeddingEngine("mlx-community/all-MiniLM-L6-v2-4bit") +engine.load() + +vectors = engine.embed(["Hello world", "How are you?"]) +print(f"Dimensions: {len(vectors[0])}") + +tokens = engine.count_tokens(["Hello world"]) +print(f"Token count: {tokens}") +``` + +## 常见问题 + +### mlx-embeddings 未安装 + +``` +pip install mlx-embeddings>=0.0.5 +``` + +### 找不到模型 + +请确认模型名称与上方请求时支持的 ID 之一匹配,或在启动服务器时通过 `--embedding-model` 指定自定义模型。你也可以提前下载支持的模型: + +```bash +huggingface-cli download mlx-community/all-MiniLM-L6-v2-4bit +``` diff --git a/docs/zh/guides/mcp-tools.md b/docs/zh/guides/mcp-tools.md new file mode 100644 index 000000000..3b52c0ab2 --- /dev/null +++ b/docs/zh/guides/mcp-tools.md @@ -0,0 +1,405 @@ +# MCP 与 tool calling + +vllm-mlx 支持 Model Context Protocol (MCP),用于将外部工具与 LLM 集成。 + +## tool calling 工作原理 + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Tool Calling Flow │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. User Request │ +│ ─────────────────► "List files in /tmp" │ +│ │ +│ 2. LLM Generates Tool Call │ +│ ─────────────────► tool_calls: [{ │ +│ name: "list_directory", │ +│ arguments: {path: "/tmp"} │ +│ }] │ +│ │ +│ 3. App Executes Tool via MCP │ +│ ─────────────────► MCP Server executes list_directory │ +│ Returns: ["file1.txt", "file2.txt"] │ +│ │ +│ 4. Tool Result Sent Back to LLM │ +│ ─────────────────► role: "tool", content: [...] │ +│ │ +│ 5. LLM Generates Final Response │ +│ ─────────────────► "The /tmp directory contains 2 files..." │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +## 快速开始 + +### 1. 
创建 MCP 配置 + +创建 `mcp.json`: + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### 2. 启动带 MCP 的服务器 + +```bash +# 简单模式 +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json + +# 连续批处理 +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json --continuous-batching +``` + +### 3. 验证 MCP 状态 + +```bash +# 查看 MCP 状态 +curl http://localhost:8000/v1/mcp/status + +# 列出可用工具 +curl http://localhost:8000/v1/mcp/tools +``` + +## tool calling 示例 + +```python +import json +import httpx + +BASE_URL = "http://localhost:8000" + +# 1. Get available tools +tools_response = httpx.get(f"{BASE_URL}/v1/mcp/tools") +tools = tools_response.json()["tools"] + +# 2. Send request with tools +response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "model": "default", + "messages": [{"role": "user", "content": "List files in /tmp"}], + "tools": tools, + "max_tokens": 1024 + } +) + +result = response.json() +message = result["choices"][0]["message"] + +# 3. Check for tool calls +if message.get("tool_calls"): + tool_call = message["tool_calls"][0] + + # 4. Execute tool via MCP + exec_response = httpx.post( + f"{BASE_URL}/v1/mcp/execute", + json={ + "server": "filesystem", + "tool": tool_call["function"]["name"], + "arguments": json.loads(tool_call["function"]["arguments"]) + } + ) + tool_result = exec_response.json() + + # 5. Send result back to LLM + messages = [ + {"role": "user", "content": "List files in /tmp"}, + message, + { + "role": "tool", + "tool_call_id": tool_call["id"], + "content": json.dumps(tool_result["result"]) + } + ] + + final_response = httpx.post( + f"{BASE_URL}/v1/chat/completions", + json={"model": "default", "messages": messages} + ) + print(final_response.json()["choices"][0]["message"]["content"]) +``` + +## MCP 接口端点 + +| 端点 | 方法 | 说明 | +|----------|--------|-------------| +| `/v1/mcp/status` | GET | 查看 MCP 状态 | +| `/v1/mcp/tools` | GET | 列出可用工具 | +| `/v1/mcp/execute` | POST | 执行工具 | + +## MCP 服务器示例 + +### 文件系统 + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + } + } +} +``` + +### GitHub + +```json +{ + "mcpServers": { + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +### PostgreSQL + +```json +{ + "mcpServers": { + "postgres": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-postgres"], + "env": { + "DATABASE_URL": "postgresql://user:pass@localhost/db" + } + } + } +} +``` + +### Brave Search + +```json +{ + "mcpServers": { + "brave-search": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-brave-search"], + "env": { + "BRAVE_API_KEY": "your-key" + } + } + } +} +``` + +## 使用多个 MCP 服务器 + +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": { + "GITHUB_TOKEN": "your-token" + } + } + } +} +``` + +## 交互式 MCP 聊天 + +如需交互式测试 MCP: + +```bash +python examples/mcp_chat.py +``` + +## 支持的工具格式 + +vllm-mlx 支持 12 种 tool call parser,覆盖所有主流模型系列。完整的 parser 列表、别名及示例请参见 [Tool Calling](tool-calling.md)。 + +## 安全性 + +vllm-mlx 内置安全措施,防止通过 MCP 服务器进行命令注入攻击。 + +### 命令白名单 + +默认情况下,仅允许可信命令: + +| 类别 | 允许的命令 | +|----------|-----------------| +| Node.js | `npx`、`npm`、`node` | 
+| Python | `uvx`、`uv`、`python`、`python3`、`pip`、`pipx` | +| Docker | `docker` | +| MCP 服务器 | `mcp-server-*`(官方服务器) | + +### 屏蔽模式 + +以下模式会被屏蔽以防止注入攻击: + +- 命令链接:`;`、`&&`、`||`、`|` +- 命令替换:`` ` ``、`$()` +- 路径穿越:`../` +- 危险环境变量:`LD_PRELOAD`、`PATH`、`PYTHONPATH` + +### 示例:被屏蔽的攻击 + +```json +{ + "mcpServers": { + "malicious": { + "command": "bash", + "args": ["-c", "rm -rf /"] + } + } +} +``` + +此配置将被拒绝: +``` +ValueError: MCP server 'malicious': Command 'bash' is not in the allowed commands whitelist. +``` + +### 开发模式(不安全) + +仅限开发环境,可绕过安全校验: + +```json +{ + "mcpServers": { + "custom": { + "command": "my-custom-server", + "skip_security_validation": true + } + } +} +``` + +**警告**:切勿在生产环境中使用 `skip_security_validation`。 + +### 自定义白名单 + +如需通过编程方式向白名单添加自定义命令: + +```python +from vllm_mlx.mcp import MCPCommandValidator, set_validator + +# Add custom commands +validator = MCPCommandValidator( + custom_whitelist={"my-trusted-server", "another-server"} +) +set_validator(validator) +``` + +## 工具执行沙箱 + +除命令校验外,vllm-mlx 还为工具执行提供运行时沙箱。 + +### 沙箱功能 + +| 功能 | 说明 | +|---------|-------------| +| 工具白名单 | 仅允许特定工具执行 | +| 工具黑名单 | 屏蔽特定危险工具 | +| 参数校验 | 屏蔽工具参数中的危险模式 | +| 频率限制 | 限制每分钟的工具调用次数 | +| 审计日志 | 记录所有工具执行情况 | + +### 屏蔽的参数模式 + +工具参数会针对以下危险模式进行校验: + +- 路径穿越:`../` +- 系统目录:`/etc/`、`/proc/`、`/sys/` +- root 访问:`/root/`、`~root` + +### 高风险工具检测 + +匹配以下模式的工具会触发安全警告: + +- `execute`、`run_command`、`shell`、`eval`、`exec`、`system`、`subprocess` + +### 自定义沙箱配置 + +```python +from vllm_mlx.mcp import ToolSandbox, set_sandbox + +# Create sandbox with custom settings +sandbox = ToolSandbox( + # Only allow specific tools (whitelist mode) + allowed_tools={"read_file", "list_directory"}, + + # Block specific tools (blacklist mode) + blocked_tools={"execute_command", "run_shell"}, + + # Rate limit: max 30 calls per minute + max_calls_per_minute=30, + + # Optional audit callback + audit_callback=lambda audit: print(f"Tool: {audit.tool_name}, Success: {audit.success}"), +) +set_sandbox(sandbox) +``` + +### 访问审计日志 + +```python +from vllm_mlx.mcp import get_sandbox + +sandbox = get_sandbox() + +# Get recent audit entries +entries = sandbox.get_audit_log(limit=50) + +# Filter by tool name +file_ops = sandbox.get_audit_log(tool_filter="file") + +# Get only errors +errors = sandbox.get_audit_log(errors_only=True) + +# Clear audit log +sandbox.clear_audit_log() +``` + +### 敏感数据脱敏 + +审计日志会自动对敏感字段(password、token、secret、key、credential、auth)进行脱敏处理,并对过大的值进行截断。 + +## 故障排查 + +### MCP 服务器无法连接 + +检查 MCP 服务器命令是否正确: +```bash +npx -y @modelcontextprotocol/server-filesystem /tmp +``` + +### 工具无法执行 + +验证工具是否可用: +```bash +curl http://localhost:8000/v1/mcp/tools | jq '.tools[].name' +``` + +### tool call 未被解析 + +请确保所用模型支持函数调用(如 Qwen3、Llama-3.2-Instruct)。 + +### 命令不在白名单中 + +如果看到 "Command X is not in the allowed commands whitelist",可采取以下措施之一: +1. 使用允许的命令(参见上方白名单) +2. 将该命令添加到自定义白名单 +3. 
使用 `skip_security_validation: true`(仅限开发环境) diff --git a/docs/zh/guides/moe-top-k.md b/docs/zh/guides/moe-top-k.md new file mode 100644 index 000000000..6ee1a85aa --- /dev/null +++ b/docs/zh/guides/moe-top-k.md @@ -0,0 +1,94 @@ +# MoE top_k 覆盖参数(`--moe-top-k`) + +减少 Mixture of Experts 模型(如 Qwen3-30B-A3B)每个 token 激活的 expert 数量,以少量质量损失换取明显更高的解码吞吐量。 + +> **状态:** 可选参数,默认行为不变。以下质量数据基于 Qwen3-30B-A3B-4bit 在 M4 Max 128 GB 上的测试结果,在将其用于生产环境前请在你的模型上自行验证。 + +## 功能说明 + +Qwen3-30B-A3B 使用 `top_k=8` 训练,即每个 token 从 128 个 expert 中选取 8 个。在 Apple Silicon 上进行 batch=1 解码时,expert 矩阵乘法(`SwitchGLU`)是每层计算中占比最大的部分,其开销与 `top_k` 大致呈线性关系。在推理阶段降低 `top_k` 已被证明(LExI 2025,Lynx 2024)能在保留大部分训练质量的同时,有效缩短解码时间。 + +`--moe-top-k N` 会遍历已加载模型的每一层,对含有 `.mlp.switch_mlp`(即稀疏 MoE 块)的层将 `top_k` 设置为 N。密集层和密集模型不受影响,该参数对它们是空操作。 + +## 用法 + +```bash +# Server +vllm-mlx serve mlx-community/Qwen3-30B-A3B-4bit \ + --continuous-batching \ + --moe-top-k 4 + +# Bench +vllm-mlx bench mlx-community/Qwen3-30B-A3B-4bit --moe-top-k 4 +``` + +若 N 大于模型训练时的 `top_k`,该参数会被拒绝,因为只有降低才有意义,不支持提高。 + +## 实测影响 + +### 解码吞吐量(M4 Max 128 GB,batch=1,贪心解码) + +| top_k | tok/s | 对比基线 | +|---:|---:|---:| +| 8(基线) | 126.5 | - | +| 6 | 136.1 | +7.6% | +| 5 | 140.3 | +10.9% | +| 4 | 147.3 | +16.5% | + +### 质量评估(Qwen3-30B-A3B-4bit,lm-evaluation-harness,MLX backend) + + + +| top_k | MMLU (acc) | GSM8K (exact match) | Δ vs baseline | +|---:|---:|---:|---:| +| 8 | TBD | TBD | - | +| 6 | TBD | TBD | TBD | +| 5 | TBD | TBD | TBD | +| 4 | TBD | TBD | TBD | + +MMLU:随机抽取 200 个样本,0-shot。 +GSM8K:随机抽取 100 个样本,0-shot,严格 exact-match。 + +以上数据具有**方向性参考价值**,完整评测集规模更大,会改变绝对精度数值,但各配置间的相对差距不会有太大变化。 + +### 贪心输出一致性 + +在 4-bit 检查点上使用 `top_k=4` 时,我们测试的所有探针提示中,生成的**前 16 个 token 与基线完全一致**。这表明 top_k=4 不会改变早期解码步骤中的 argmax,模型对减少一半激活 expert 具有内在的鲁棒性。 + +当 `top_k=3` 或更低时,质量会出现可见的下降(此处未测量,基于 LExI 论文推断),因此该参数在配置校验层刻意不允许低于 1,但生产环境推荐的最低值为 `top_k=4`。 + +## 适用场景与不适用场景 + +适合使用的情况: +- 运行 Qwen3 MoE(或兼容模型:Qwen3.5 MoE、Gemma-MoE),且单用户解码吞吐量是瓶颈。 +- 工作负载允许少量质量损失,以换取明显的延迟改善。 +- 部署在受内存带宽限制的硬件上(M 系列 Apple Silicon),expert gather 主导每步解码时间。 + +不适合使用的情况: +- 运行密集模型,该参数是空操作,没有任何效果。 +- 对评测集排行榜精度有顶尖要求。 +- 运行长链式推理或"思考模式"生成,质量下降幅度可能比 0-shot MMLU 所示更陡。 + +## 与其他优化叠加使用 + +该参数可与量化叠加使用。在 Qwen3-30B-A3B-4bit 上的实测叠加结果如下: + +- 4-bit + top_k=8:126.5 tok/s(基线) +- 4-bit + top_k=4:147.3 tok/s(+16.5%) +- 3-bit + top_k=8:138.6 tok/s(+9.6%) +- 3-bit + top_k=6:147.1 tok/s(+16.3%),质量差异可测量 +- 3-bit + top_k=4:157.3 tok/s(+24%),**输出质量严重下降**(在冒烟测试中模型回答了不同的问题) + +3-bit + top_k=4 的数值误差累积超出了 argmax 稳定的临界点。最多只应使用一个激进参数:4-bit + top_k=4 或 3-bit + top_k=6。两者的 tok/s 大致相同(约 147),但质量表现差异显著。 + +## 内部实现 + +- 补丁辅助函数:`vllm_mlx.scheduler.apply_moe_top_k_override(model, k)` +- 在 `Scheduler.__init__` 中于模型加载完成后执行。 +- 测试:`tests/test_moe_top_k.py`,覆盖密集模型、混合架构及校验路径。 + +## 参考资料 + +- LExI: Layer-Adaptive Active Experts, [arXiv 2509.02753](https://arxiv.org/html/2509.02753) +- Not All Experts are Equal (NAEE), [ACL 2024](https://aclanthology.org/2024.acl-long.334.pdf) +- SwiftLM (`SWIFTLM_TOP_K` env knob prior art), [github.com/SharpAI/SwiftLM](https://github.com/SharpAI/SwiftLM) diff --git a/docs/zh/guides/multimodal.md b/docs/zh/guides/multimodal.md new file mode 100644 index 000000000..71a673bf4 --- /dev/null +++ b/docs/zh/guides/multimodal.md @@ -0,0 +1,315 @@ +# 多模态模型(图像与视频) + +vllm-mlx 支持用于图像和视频理解的视觉语言模型。 + +## 支持的模型 + +- Qwen3-VL(推荐) +- Qwen2-VL +- Gemma 3 +- LLaVA +- Idefics +- PaliGemma +- Pixtral +- Molmo +- DeepSeek-VL + +## 启动多模态服务器 + +```bash +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --port 8000 +``` + +名称中包含 "VL"、"Vision" 或 "mllm" 的模型会被自动识别为多模态模型。 + 
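+下面是一个仅作示意的小例子,用来说明上述基于名称的识别规则;其中的 `looks_multimodal` 是为说明而虚构的辅助函数,实际检测逻辑由 vllm-mlx 内部实现,可能与此不同:
+
+```python
+def looks_multimodal(model_name: str) -> bool:
+    """仅作示意:模仿文档描述的名称规则,并非 vllm-mlx 的内部实现。"""
+    name = model_name.lower()
+    return any(marker in name for marker in ("vl", "vision", "mllm"))
+
+print(looks_multimodal("mlx-community/Qwen3-VL-4B-Instruct-3bit"))   # True
+print(looks_multimodal("mlx-community/Llama-3.2-3B-Instruct-4bit"))  # False
+```
+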
+## 图像分析 + +### 通过 OpenAI SDK + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Image from URL +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + max_tokens=256 +) +print(response.choices[0].message.content) +``` + +### Base64 图像 + +```python +import base64 + +def encode_image(path): + with open(path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + +base64_image = encode_image("photo.jpg") +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + ] + }] +) +``` + +### 通过 curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} + ] + }], + "max_tokens": 256 + }' +``` + +## 视频分析 + +### 通过 OpenAI SDK + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What happens in this video?"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + max_tokens=512 +) +``` + +### 视频参数 + +通过额外的请求体参数控制帧提取: + +```python +response = client.chat.completions.create( + model="default", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "video.mp4"}} + ] + }], + extra_body={ + "video_fps": 2.0, + "video_max_frames": 32 + } +) +``` + +### 通过 curl + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "Describe this video"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }], + "video_fps": 2.0, + "video_max_frames": 16 + }' +``` + +## 支持的格式 + +### 图像 + +| 格式 | 示例 | +|------|------| +| URL | `{"type": "image_url", "image_url": {"url": "https://..."}}` | +| 本地文件 | `{"type": "image_url", "image_url": {"url": "/path/to/image.jpg"}}` | +| Base64 | `{"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}` | + +### 视频 + +| 格式 | 示例 | +|------|------| +| URL | `{"type": "video_url", "video_url": {"url": "https://..."}}` | +| 本地文件 | `{"type": "video", "video": "/path/to/video.mp4"}` | +| Base64 | `{"type": "video_url", "video_url": {"url": "data:video/mp4;base64,..."}}` | + +## Python API + +```python +from vllm_mlx.models import MLXMultimodalLM + +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Image +description = mllm.describe_image("photo.jpg") + +# Video +description = mllm.describe_video("video.mp4", fps=2.0) + +# Custom prompt +output = mllm.generate( + prompt="Compare these images", + images=["img1.jpg", "img2.jpg"] +) +``` + +## 性能建议 + +### 图像 +- 分辨率越小,处理速度越快(224x224 对比 1920x1080) +- 根据任务选择合适的分辨率 + +### 视频 +- 帧率越低,处理速度越快 +- 帧数越少,内存占用越低 +- 64 帧是实际可用的最大值(96 帧及以上会导致 GPU 超时) + +## 基准测试 + +在配备 128 GB 统一内存的 Apple M4 Max 上测试。 + 
+### Qwen3-VL-4B-Instruct-3bit + +| 分辨率 | 耗时 | token 数 | 速度 | 内存 | +|--------|------|----------|------|------| +| 224x224 | 0.87s | 124 | 143 tok/s | 2.6 GB | +| 448x448 | 1.01s | 107 | 106 tok/s | 3.1 GB | +| 768x768 | 1.42s | 127 | 89 tok/s | 3.4 GB | +| 1024x1024 | 1.85s | 116 | 63 tok/s | 3.6 GB | + +### Qwen3-VL-8B-Instruct-4bit + +| 分辨率 | 耗时 | token 数 | 速度 | 内存 | +|--------|------|----------|------|------| +| 224x224 | 1.08s | 78 | 73 tok/s | 5.6 GB | +| 448x448 | 1.41s | 70 | 50 tok/s | 6.1 GB | +| 768x768 | 2.06s | 91 | 44 tok/s | 6.5 GB | +| 1024x1024 | 3.02s | 76 | 25 tok/s | 7.6 GB | + +### Gemma 3 4B 4bit + +| 分辨率 | 耗时 | token 数 | 速度 | 内存 | +|--------|------|----------|------|------| +| 224x224 | 0.95s | 30 | 32 tok/s | 5.2 GB | +| 448x448 | 0.99s | 34 | 34 tok/s | 5.2 GB | +| 768x768 | 0.99s | 32 | 32 tok/s | 5.2 GB | +| 1024x1024 | 0.95s | 28 | 29 tok/s | 5.2 GB | + +### 运行基准测试 + +```bash +# Quick benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --quick + +# Full benchmark with more resolutions +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Video benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-4B-Instruct-3bit --video +``` + +## MLLM Cache + +vllm-mlx 为多模态模型内置了 prefix cache 系统,可以显著加速对相同图像的重复请求。 + +### 工作原理 + +向模型发送图像时,视觉编码器会将其处理为嵌入向量,该过程通常需要 1 到 2 秒。MLLM Cache 会同时存储这些嵌入向量和 KV cache 状态,因此后续使用相同图像的请求可以完全跳过视觉编码器。 + +该缓存采用基于内容的哈希(类似 LMCache)来识别相同图像,无论图像以何种方式提供(URL、base64 还是文件路径)。 + +### 启用缓存 + +```bash +# Enable with default settings (512 MB max) +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit --enable-mllm-cache + +# With custom memory limit +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit \ + --enable-mllm-cache \ + --mllm-cache-max-mb 1024 +``` + +### Python API + +```python +from vllm_mlx.mllm_cache import MLLMPrefixCacheManager + +# Create cache manager +cache = MLLMPrefixCacheManager(max_memory_mb=512) + +# Store embeddings and KV cache after processing +cache.store( + images=["photo.jpg"], + prompt="Describe this image", + vision_embeddings=embeddings, + kv_cache=kv_state, + num_tokens=128 +) + +# Fetch from cache on subsequent requests +entry, match_len = cache.fetch(images=["photo.jpg"], prompt="Describe this image") +if entry: + # Use cached embeddings, skip vision encoder + embeddings = entry.vision_embeddings + kv_state = entry.kv_cache +``` + +### 缓存统计 + +```python +stats = cache.get_stats() +print(f"Hit rate: {stats.hit_rate:.1%}") +print(f"Memory used: {stats.memory_used_mb:.1f} MB") +print(f"Tokens saved: {stats.tokens_saved}") +``` + +### 内存管理 + +当达到内存上限时,缓存采用 LRU(最近最少使用)策略进行淘汰。每个缓存条目记录以下信息: + +- 视觉嵌入向量大小 +- 每层 KV cache 大小 +- 用于 LRU 排序的访问频率 + +当内存压力出现时,最近最少访问的条目会被优先淘汰。 + +## Gradio 聊天界面 + +如需交互式多模态对话: + +```bash +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit +``` + +支持拖放图像和视频。 diff --git a/docs/zh/guides/python-api.md b/docs/zh/guides/python-api.md new file mode 100644 index 000000000..0e1f2dbaa --- /dev/null +++ b/docs/zh/guides/python-api.md @@ -0,0 +1,182 @@ +# Python API + +通过 Python API 直接以编程方式访问 vllm-mlx。 + +## Language Models + +### 基本用法 + +```python +from vllm_mlx.models import MLXLanguageModel + +# Load model +model = MLXLanguageModel("mlx-community/Llama-3.2-3B-Instruct-4bit") +model.load() + +# Generate text +output = model.generate("What is the capital of France?", max_tokens=100) +print(output.text) +``` + +### Streaming Generation + +```python +for chunk in model.stream_generate("Tell me a story about a robot"): + print(chunk.text, end="", flush=True) 
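+# Each chunk.text holds only the newly generated fragment; end="" and flush=True show it as it streams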
+``` + +### Chat 接口 + +```python +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, who are you?"} +] +response = model.chat(messages) +print(response.text) +``` + +### 生成参数 + +```python +output = model.generate( + prompt="Write a poem", + max_tokens=256, + temperature=0.7, + top_p=0.9, + stop=["END", "\n\n"] +) +``` + +| 参数 | 描述 | 默认值 | +|-----------|-------------|---------| +| `max_tokens` | 最大生成 token 数量 | 256 | +| `temperature` | 采样温度(0-2) | 0.7 | +| `top_p` | Nucleus sampling | 0.9 | +| `stop` | 停止序列 | None | + +## Vision-Language Models + +### 基本用法 + +```python +from vllm_mlx.models import MLXMultimodalLM + +# Load model +mllm = MLXMultimodalLM("mlx-community/Qwen3-VL-4B-Instruct-3bit") +mllm.load() + +# Describe an image +description = mllm.describe_image("photo.jpg") +print(description) +``` + +### 问答 + +```python +answer = mllm.answer_about_image("photo.jpg", "What color is the car?") +print(answer) +``` + +### 多图片输入 + +```python +output = mllm.generate( + prompt="Compare these two images", + images=["image1.jpg", "image2.jpg"] +) +print(output.text) +``` + +### 视频理解 + +```python +# From local file +output = mllm.generate( + prompt="What is happening in this video?", + videos=["video.mp4"], + video_fps=2.0, + video_max_frames=16 +) +print(output.text) + +# From URL +output = mllm.generate( + prompt="Describe this video", + videos=["https://example.com/video.mp4"], + video_fps=2.0 +) + +# Convenience method +description = mllm.describe_video("video.mp4", fps=2.0) +``` + +### 视频参数 + +| 参数 | 描述 | 默认值 | +|-----------|-------------|---------| +| `video_fps` | 每秒提取帧数 | 2.0 | +| `video_max_frames` | 最大处理帧数 | 32 | + +## Engine API + +针对高级使用场景,可直接使用 engine: + +### Simple Engine + +```python +from vllm_mlx.engine import SimpleEngine + +engine = SimpleEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) +print(output.text) + +await engine.stop() +``` + +### Batched Engine + +```python +from vllm_mlx.engine import BatchedEngine + +engine = BatchedEngine("mlx-community/Llama-3.2-3B-Instruct-4bit") +await engine.start() + +# Multiple concurrent requests +output = await engine.generate( + prompt="Hello world", + max_tokens=100 +) + +await engine.stop() +``` + +## 输出格式 + +所有生成方法均返回 `GenerationOutput`: + +```python +output = model.generate("Hello") + +print(output.text) # Generated text +print(output.prompt_tokens) # Input token count +print(output.completion_tokens) # Output token count +print(output.finish_reason) # "stop" or "length" +``` + +## 错误处理 + +```python +from vllm_mlx.models import MLXLanguageModel + +try: + model = MLXLanguageModel("invalid-model") + model.load() +except Exception as e: + print(f"Failed to load model: {e}") +``` diff --git a/docs/zh/guides/reasoning.md b/docs/zh/guides/reasoning.md new file mode 100644 index 000000000..3aacaac71 --- /dev/null +++ b/docs/zh/guides/reasoning.md @@ -0,0 +1,263 @@ +# Reasoning 模型 + +vllm-mlx 支持在给出答案之前展示 thinking 过程的 reasoning 模型。Qwen3 和 DeepSeek-R1 等模型会将 reasoning 内容包裹在 `...` 标签中,vllm-mlx 可以解析这些标签,将 reasoning 与最终回答分离。 + +## 为什么使用 Reasoning 解析? + +reasoning 模型生成的原始输出通常如下所示: + +``` + +Let me analyze this step by step. +First, I need to consider the constraints. +The answer should be a prime number less than 10. +Checking: 2, 3, 5, 7 are all prime and less than 10. + +The prime numbers less than 10 are: 2, 3, 5, 7. 
+```
+
+不启用 reasoning 解析时,响应中会包含原始标签。启用 reasoning parsing 后,thinking 过程与最终回答会被分离到 API 响应的不同字段中。
+
+## 快速开始
+
+### 启动服务器并指定 Reasoning Parser
+
+```bash
+# For Qwen3 models
+vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3
+
+# For DeepSeek-R1 models
+vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1
+```
+
+### API 响应格式
+
+启用 reasoning parsing 后,API 响应中会包含 `reasoning` 字段。
+
+**非 streaming 响应:**
+
+```json
+{
+  "choices": [{
+    "message": {
+      "role": "assistant",
+      "content": "The prime numbers less than 10 are: 2, 3, 5, 7.",
+      "reasoning": "Let me analyze this step by step.\nFirst, I need to consider the constraints.\nThe answer should be a prime number less than 10.\nChecking: 2, 3, 5, 7 are all prime and less than 10."
+    }
+  }]
+}
+```
+
+**Streaming 响应:**
+
+reasoning 和正文内容分块独立发送。在 reasoning 阶段,数据块的 `reasoning` 字段有内容;当模型进入最终回答阶段后,数据块的 `content` 字段有内容:
+
+```json
+{"delta": {"reasoning": "Let me analyze"}}
+{"delta": {"reasoning": " this step by step."}}
+{"delta": {"reasoning": "\nFirst, I need to"}}
+...
+{"delta": {"content": "The prime"}}
+{"delta": {"content": " numbers less than 10"}}
+{"delta": {"content": " are: 2, 3, 5, 7."}}
+```
+
+## 与 OpenAI SDK 配合使用
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
+
+# Non-streaming
+response = client.chat.completions.create(
+    model="default",
+    messages=[{"role": "user", "content": "What are the prime numbers less than 10?"}]
+)
+
+message = response.choices[0].message
+print("Reasoning:", message.reasoning)  # The thinking process
+print("Answer:", message.content)  # The final answer
+```
+
+### Streaming 与 Reasoning
+
+```python
+reasoning_text = ""
+content_text = ""
+
+stream = client.chat.completions.create(
+    model="default",
+    messages=[{"role": "user", "content": "Solve: 2 + 2 = ?"}],
+    stream=True
+)
+
+for chunk in stream:
+    delta = chunk.choices[0].delta
+    if hasattr(delta, 'reasoning') and delta.reasoning:
+        reasoning_text += delta.reasoning
+        print(f"[Thinking] {delta.reasoning}", end="")
+    if delta.content:
+        content_text += delta.content
+        print(delta.content, end="")
+
+print(f"\n\nFinal reasoning: {reasoning_text}")
+print(f"Final answer: {content_text}")
+```
+
+## 支持的 Parser
+
+### Qwen3 Parser(`qwen3`)
+
+适用于使用显式 `<think>` 和 `</think>` 标签的 Qwen3 模型。
+
+- 需要开标签和闭标签**同时存在**
+- 如果标签缺失,输出将被视为普通内容
+- 适合:Qwen3-0.6B、Qwen3-4B、Qwen3-8B 及同系列模型
+
+```bash
+vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3
+```
+
+### DeepSeek-R1 Parser(`deepseek_r1`)
+
+适用于可能省略开标签 `<think>` 的 DeepSeek-R1 模型。
+
+- 比 Qwen3 parser 更宽松
+- 能处理 `<think>` 为隐式的情况
+- 即使没有 `<think>`,`</think>` 之前的内容也会被视为 reasoning
+
+```bash
+vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1
+```
+
+## 工作原理
+
+reasoning parser 通过基于文本的检测来识别模型输出中的 thinking 标签。在 streaming 过程中,它会追踪当前在输出中的位置,将每个 token 正确路由到 `reasoning` 或 `content` 字段。
+
+```
+Model Output: <think>Step 1: analyze...</think>The answer is 42.
+ ├─────────────────────┤├─────────────────────┤ +Parsed: │ reasoning ││ content │ + └─────────────────────┘└─────────────────────┘ +``` + +解析过程是无状态的,通过累积文本来判断上下文,在 token 以任意分块到达的 streaming 场景下也能稳定工作。 + +## 最佳使用建议 + +### 提示词写法 + +引导模型逐步思考,reasoning 模型的效果更好: + +```python +messages = [ + {"role": "system", "content": "Think through problems step by step before answering."}, + {"role": "user", "content": "What is 17 × 23?"} +] +``` + +### 处理缺失的 Reasoning + +某些提示词可能不会触发 reasoning。此时 `reasoning` 值为 `None`,所有输出都进入 `content`: + +```python +message = response.choices[0].message +if message.reasoning: + print(f"Model's thought process: {message.reasoning}") +print(f"Answer: {message.content}") +``` + +### 温度参数与 Reasoning + +较低的温度通常会产生更稳定的 reasoning 模式: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Explain quantum entanglement"}], + temperature=0.3 # More focused reasoning +) +``` + +## 向后兼容性 + +未指定 `--reasoning-parser` 时,服务器行为与之前一致:thinking 标签包含在 `content` 字段中,响应中不会添加 `reasoning` 字段。这确保现有应用无需修改即可继续正常使用。 + +## 示例:数学题求解器 + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +def solve_math(problem: str) -> dict: + """Solve a math problem and return reasoning + answer.""" + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a math tutor. Show your work."}, + {"role": "user", "content": problem} + ], + temperature=0.2 + ) + + message = response.choices[0].message + return { + "problem": problem, + "work": message.reasoning, + "answer": message.content + } + +result = solve_math("If a train travels 120 km in 2 hours, what is its average speed?") +print(f"Problem: {result['problem']}") +print(f"\nWork shown:\n{result['work']}") +print(f"\nFinal answer: {result['answer']}") +``` + +## Curl 示例 + +### 非 Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}] + }' +``` + +### Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "What is 15% of 80?"}], + "stream": true + }' +``` + +## 常见问题排查 + +### 响应中没有 reasoning 字段 + +- 确认启动服务器时指定了 `--reasoning-parser` +- 检查模型是否实际使用了 thinking 标签(并非所有提示词都会触发 reasoning) + +### Reasoning 出现在 content 中 + +- 模型可能没有使用预期的标签格式 +- 尝试换用其他 parser(`qwen3` 或 `deepseek_r1`) + +### Reasoning 被截断 + +- 如果模型在 thinking 过程中触及了 token 上限,请增大 `--max-tokens` + +## 相关链接 + +- [支持的模型](../reference/models.md). 支持 reasoning 的模型列表 +- [服务器配置](server.md). 所有服务器选项 +- [CLI 参考](../reference/cli.md). 
命令行选项 diff --git a/docs/zh/guides/server.md b/docs/zh/guides/server.md new file mode 100644 index 000000000..d099d5212 --- /dev/null +++ b/docs/zh/guides/server.md @@ -0,0 +1,780 @@ +# OpenAI 兼容服务器 + +vllm-mlx 提供一个具备完整 OpenAI API 兼容性的 FastAPI 服务器。 + +默认情况下,服务器仅绑定到 `127.0.0.1`。只有在明确需要将其暴露到本机以外的网络时,才使用 `--host 0.0.0.0`。 + +## 启动服务器 + +### 简单模式(默认) + +单用户最大吞吐量: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 +``` + +### continuous batching 模式 + +适用于多个并发用户: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching +``` + +### 启用 paged cache + +适用于生产环境的高效内存缓存: + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 --continuous-batching --use-paged-cache +``` + +## 服务器选项 + +| 选项 | 说明 | 默认值 | +|------|------|--------| +| `--port` | 服务器端口 | 8000 | +| `--host` | 服务器主机 | 127.0.0.1 | +| `--api-key` | 身份验证用 API key | None | +| `--rate-limit` | 每客户端每分钟请求数(0 表示禁用) | 0 | +| `--timeout` | 请求超时时间(秒) | 300 | +| `--enable-metrics` | 在 `/metrics` 上暴露 Prometheus 指标 | False | +| `--continuous-batching` | 为多用户启用 batching | False | +| `--use-paged-cache` | 启用 paged KV cache | False | +| `--cache-memory-mb` | 缓存内存上限(MB) | Auto | +| `--cache-memory-percent` | 用于缓存的 RAM 比例 | 0.20 | +| `--max-tokens` | 默认最大 token 数 | 32768 | +| `--max-request-tokens` | API 客户端可传入的 `max_tokens` 最大值 | 32768 | +| `--default-temperature` | 未指定时的默认 temperature | None | +| `--default-top-p` | 未指定时的默认 top_p | None | +| `--stream-interval` | 每个 streaming chunk 包含的 token 数 | 1 | +| `--mcp-config` | MCP 配置文件路径 | None | +| `--reasoning-parser` | reasoning 模型解析器(`qwen3`、`deepseek_r1`) | None | +| `--embedding-model` | 启动时预加载 embeddings 模型 | None | +| `--enable-auto-tool-choice` | 启用自动 tool calling | False | +| `--tool-call-parser` | tool call 解析器(参见 [Tool Calling](tool-calling.md)) | None | + +## API 端点 + +### Chat Completions + +```bash +POST /v1/chat/completions +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +# Non-streaming +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100 +) + +# Streaming +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "Tell me a story"}], + stream=True +) +for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") +``` + +### Completions + +```bash +POST /v1/completions +``` + +```python +response = client.completions.create( + model="default", + prompt="The capital of France is", + max_tokens=50 +) +``` + +### Models + +```bash +GET /v1/models +``` + +返回可用模型列表。 + +### Embeddings + +```bash +POST /v1/embeddings +``` + +```python +response = client.embeddings.create( + model="mlx-community/multilingual-e5-small-mlx", + input="Hello world" +) +print(response.data[0].embedding[:5]) # First 5 dimensions +``` + +详情参见 [Embeddings 指南](embeddings.md)。 + +### 健康检查 + +```bash +GET /health +``` + +返回服务器状态。 + +### 指标 + +```bash +GET /metrics +``` + +Prometheus 抓取端点,提供服务器、缓存、scheduler 及请求指标。该端点默认禁用,需通过 `--enable-metrics` 启用。 + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-metrics +``` + +`/metrics` 端点有意不做身份验证。请仅在受信任的网络中暴露,或通过反向代理、防火墙限制访问来源。 + +### Anthropic Messages API + +```bash +POST /v1/messages +``` + +兼容 Anthropic 协议的端点,允许 Claude Code、OpenCode 等工具直接连接 vllm-mlx。内部将 Anthropic 请求转换为 OpenAI 格式,经引擎推理后再将响应转换回 Anthropic 格式。 + +功能: +- 非 streaming 与 
streaming 响应(SSE) +- 系统消息(纯字符串或内容块列表) +- 包含用户和助手消息的多轮对话 +- 使用 `tool_use` / `tool_result` 内容块进行 tool calling +- 用于预算追踪的 token 计数 +- 多模态内容(通过 `source` 块传入图片) +- 客户端断开检测(返回 HTTP 499) +- streaming 输出中的特殊 token 自动过滤 + +#### 非 streaming + +```python +from anthropic import Anthropic + +client = Anthropic(base_url="http://localhost:8000", api_key="not-needed") + +response = client.messages.create( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Hello!"}] +) +print(response.content[0].text) +# Response includes: response.id, response.model, response.stop_reason, +# response.usage.input_tokens, response.usage.output_tokens +``` + +#### Streaming + +streaming 遵循 Anthropic SSE 事件协议,事件按以下顺序发出: +`message_start` -> `content_block_start` -> `content_block_delta`(重复)-> `content_block_stop` -> `message_delta` -> `message_stop` + +```python +with client.messages.stream( + model="default", + max_tokens=256, + messages=[{"role": "user", "content": "Tell me a story"}] +) as stream: + for text in stream.text_stream: + print(text, end="") +``` + +#### 系统消息 + +系统消息可以是纯字符串,也可以是内容块列表: + +```python +# Plain string +response = client.messages.create( + model="default", + max_tokens=256, + system="You are a helpful coding assistant.", + messages=[{"role": "user", "content": "Write a hello world in Python"}] +) + +# List of content blocks +response = client.messages.create( + model="default", + max_tokens=256, + system=[ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Be concise in your answers."}, + ], + messages=[{"role": "user", "content": "What is 2+2?"}] +) +``` + +#### Tool calling + +使用 `name`、`description` 和 `input_schema` 定义工具。模型在需要调用工具时会返回 `tool_use` 内容块。将结果以 `tool_result` 块的形式返回。 + +```python +# Step 1: Send request with tools +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) + +# Step 2: Check if model wants to use tools +for block in response.content: + if block.type == "tool_use": + print(f"Tool: {block.name}, Input: {block.input}, ID: {block.id}") + # response.stop_reason will be "tool_use" + +# Step 3: Send tool result back +response = client.messages.create( + model="default", + max_tokens=1024, + messages=[ + {"role": "user", "content": "What's the weather in Paris?"}, + {"role": "assistant", "content": response.content}, + {"role": "user", "content": [ + { + "type": "tool_result", + "tool_use_id": block.id, + "content": "Sunny, 22C" + } + ]} + ], + tools=[{ + "name": "get_weather", + "description": "Get weather for a city", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + }] +) +print(response.content[0].text) # "The weather in Paris is sunny, 22C." 
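+
+# Hedged sketch, not shown in the steps above: an agent normally repeats steps
+# 2-3 until the model stops asking for tools. This assumes you keep the
+# conversation in a `messages` list and your tool definitions in a `tools` list,
+# and that `run_tool(name, input)` is a hypothetical dispatcher you implement.
+while response.stop_reason == "tool_use":
+    tool_results = [
+        {"type": "tool_result", "tool_use_id": b.id, "content": run_tool(b.name, b.input)}
+        for b in response.content if b.type == "tool_use"
+    ]
+    messages.append({"role": "assistant", "content": response.content})
+    messages.append({"role": "user", "content": tool_results})
+    response = client.messages.create(
+        model="default", max_tokens=1024, messages=messages, tools=tools
+    )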
+``` + +Tool choice 模式: + +| `tool_choice` | 行为 | +|---------------|------| +| `{"type": "auto"}` | 由模型决定是否调用工具(默认) | +| `{"type": "any"}` | 模型必须至少调用一个工具 | +| `{"type": "tool", "name": "get_weather"}` | 模型必须调用指定工具 | +| `{"type": "none"}` | 模型不调用任何工具 | + +#### 多轮对话 + +```python +messages = [ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What's my name?"}, +] + +response = client.messages.create( + model="default", + max_tokens=100, + messages=messages +) +``` + +#### Token 计数 + +```bash +POST /v1/messages/count_tokens +``` + +使用模型的 tokenizer 统计 Anthropic 请求的输入 token 数。适用于在发送请求前进行预算追踪。可统计系统消息、对话消息、tool_use 输入、tool_result 内容及工具定义(name、description、input_schema)中的 token。 + +```python +import requests + +resp = requests.post("http://localhost:8000/v1/messages/count_tokens", json={ + "model": "default", + "messages": [{"role": "user", "content": "Hello, how are you?"}], + "system": "You are helpful.", + "tools": [{ + "name": "search", + "description": "Search the web", + "input_schema": {"type": "object", "properties": {"q": {"type": "string"}}} + }] +}) +print(resp.json()) # {"input_tokens": 42} +``` + +#### curl 示例 + +非 streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + +Streaming: + +```bash +curl http://localhost:8000/v1/messages \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "max_tokens": 256, + "stream": true, + "messages": [{"role": "user", "content": "Tell me a joke"}] + }' +``` + +Token 计数: + +```bash +curl http://localhost:8000/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}] + }' +# {"input_tokens": 12} +``` + +#### 请求字段 + +| 字段 | 类型 | 是否必填 | 默认值 | 说明 | +|------|------|----------|--------|------| +| `model` | string | 是 | - | 模型名称(使用 `"default"` 指向已加载的模型) | +| `messages` | list | 是 | - | 包含 `role` 和 `content` 的对话消息 | +| `max_tokens` | int | 是 | - | 最大生成 token 数 | +| `system` | string 或 list | 否 | null | 系统提示(字符串或 `{"type": "text", "text": "..."}` 块列表) | +| `stream` | bool | 否 | false | 启用 SSE streaming | +| `temperature` | float | 否 | 0.7 | 采样 temperature(0.0 为确定性,1.0 为创意性) | +| `top_p` | float | 否 | 0.9 | nucleus sampling 阈值 | +| `top_k` | int | 否 | null | top-k sampling | +| `stop_sequences` | list | 否 | null | 触发停止生成的序列 | +| `tools` | list | 否 | null | 包含 `name`、`description`、`input_schema` 的工具定义 | +| `tool_choice` | dict | 否 | null | 工具选择模式(`auto`、`any`、`tool`、`none`) | +| `metadata` | dict | 否 | null | 任意元数据(透传,服务器不使用) | + +#### 响应格式 + +非 streaming 响应: + +```json +{ + "id": "msg_abc123...", + "type": "message", + "role": "assistant", + "model": "default", + "content": [ + {"type": "text", "text": "Hello! 
How can I help?"} + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 12, + "output_tokens": 8 + } +} +``` + +调用工具时,`content` 包含 `tool_use` 块,且 `stop_reason` 为 `"tool_use"`: + +```json +{ + "content": [ + {"type": "text", "text": "Let me check the weather."}, + { + "type": "tool_use", + "id": "call_abc123", + "name": "get_weather", + "input": {"city": "Paris"} + } + ], + "stop_reason": "tool_use" +} +``` + +停止原因: + +| `stop_reason` | 含义 | +|---------------|------| +| `end_turn` | 模型自然完成生成 | +| `tool_use` | 模型需要调用工具 | +| `max_tokens` | 达到 `max_tokens` 上限 | + +#### 与 Claude Code 配合使用 + +将 Claude Code 直接指向你的 vllm-mlx 服务器: + +```bash +# Start the server +vllm-mlx serve mlx-community/Qwen3-Coder-Next-235B-A22B-4bit \ + --continuous-batching \ + --enable-auto-tool-choice \ + --tool-call-parser hermes + +# In another terminal, configure Claude Code +export ANTHROPIC_BASE_URL=http://localhost:8000 +export ANTHROPIC_API_KEY=not-needed +claude +``` + +### 服务器状态 + +```bash +GET /v1/status +``` + +实时监控端点,返回服务器全局统计信息及每个请求的详情。适用于调试性能、追踪缓存效率以及监控 Metal GPU 内存。 + +```bash +curl -s http://localhost:8000/v1/status | python -m json.tool +``` + +示例响应: + +```json +{ + "status": "running", + "model": "mlx-community/Qwen3-8B-4bit", + "uptime_s": 342.5, + "steps_executed": 1247, + "num_running": 1, + "num_waiting": 0, + "total_requests_processed": 15, + "total_prompt_tokens": 28450, + "total_completion_tokens": 3200, + "metal": { + "active_memory_gb": 5.2, + "peak_memory_gb": 8.1, + "cache_memory_gb": 2.3 + }, + "cache": { + "type": "memory_aware_cache", + "entries": 5, + "hit_rate": 0.87, + "memory_mb": 2350 + }, + "requests": [ + { + "request_id": "req_abc123", + "phase": "generation", + "tokens_per_second": 45.2, + "ttft_s": 0.8, + "progress": 0.35, + "cache_hit_type": "prefix", + "cached_tokens": 1200, + "generated_tokens": 85, + "max_tokens": 256 + } + ] +} +``` + +响应字段: + +| 字段 | 说明 | +|------|------| +| `status` | 服务器状态:`running`、`stopped` 或 `not_loaded` | +| `model` | 已加载模型的名称 | +| `uptime_s` | 服务器启动后经过的秒数 | +| `steps_executed` | 已执行的推理步骤总数 | +| `num_running` | 当前正在生成 token 的请求数 | +| `num_waiting` | 排队等待 prefill 的请求数 | +| `total_requests_processed` | 启动以来已完成的请求总数 | +| `total_prompt_tokens` | 启动以来处理的 prompt token 总数 | +| `total_completion_tokens` | 启动以来生成的 completion token 总数 | +| `metal.active_memory_gb` | 当前使用的 Metal GPU 内存(GB) | +| `metal.peak_memory_gb` | Metal GPU 内存峰值用量(GB) | +| `metal.cache_memory_gb` | Metal 缓存内存用量(GB) | +| `cache` | 缓存统计信息(类型、条目数、命中率、内存用量) | +| `requests` | 活跃请求列表,包含每个请求的详细信息 | + +`requests` 中的每请求字段: + +| 字段 | 说明 | +|------|------| +| `request_id` | 唯一请求标识符 | +| `phase` | 当前阶段:`queued`、`prefill` 或 `generation` | +| `tokens_per_second` | 该请求的生成吞吐量 | +| `ttft_s` | 首 token 时间(秒),即 TTFT | +| `progress` | 完成进度(0.0 到 1.0) | +| `cache_hit_type` | 缓存匹配类型:`exact`、`prefix`、`supersequence`、`lcp` 或 `miss` | +| `cached_tokens` | 从缓存中命中的 token 数 | +| `generated_tokens` | 已生成的 token 数 | +| `max_tokens` | 请求的最大 token 数 | + +## Tool Calling + +使用 `--enable-auto-tool-choice` 启用兼容 OpenAI 的 tool calling: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +使用 `--tool-call-parser` 选项为你的模型选择对应的解析器: + +| 解析器 | 适用模型 | +|--------|----------| +| `auto` | 自动检测(依次尝试所有解析器) | +| `mistral` | Mistral、Devstral | +| `qwen` | Qwen、Qwen3 | +| `llama` | Llama 3.x、4.x | +| `hermes` | Hermes、NousResearch | +| `deepseek` | DeepSeek V3、R1 | +| `kimi` | Kimi K2、Moonshot | +| `granite` | IBM Granite 
3.x、4.x | +| `nemotron` | NVIDIA Nemotron | +| `xlam` | Salesforce xLAM | +| `functionary` | MeetKai Functionary | +| `glm47` | GLM-4.7、GLM-4.7-Flash | + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + } + } + }] +) + +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"{tc.function.name}: {tc.function.arguments}") +``` + +完整文档参见 [Tool Calling 指南](tool-calling.md)。 + +## Reasoning 模型 + +对于展示思考过程的模型(Qwen3、DeepSeek-R1),使用 `--reasoning-parser` 将 reasoning 内容与最终答案分离: + +```bash +# Qwen3 models +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# DeepSeek-R1 models +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 +``` + +API 响应中包含 `reasoning` 字段,用于展示模型的思考过程: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What is 17 × 23?"}] +) + +print(response.choices[0].message.reasoning) # Step-by-step thinking +print(response.choices[0].message.content) # Final answer +``` + +streaming 时,reasoning chunk 先于 content chunk 到达: + +```python +for chunk in stream: + delta = chunk.choices[0].delta + if delta.reasoning: + print(f"[Thinking] {delta.reasoning}") + if delta.content: + print(delta.content, end="") +``` + +完整说明参见 [Reasoning 模型指南](reasoning.md)。 + +## 结构化输出(JSON 模式) + +使用 `response_format` 强制模型返回合法 JSON。 + +### JSON Object 模式 + +返回任意合法 JSON: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={"type": "json_object"} +) +# Output: {"colors": ["red", "blue", "green"]} +``` + +### JSON Schema 模式 + +返回符合指定 schema 的 JSON: + +```python +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "List 3 colors"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "colors", + "schema": { + "type": "object", + "properties": { + "colors": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["colors"] + } + } + } +) +# Output validated against schema +data = json.loads(response.choices[0].message.content) +assert "colors" in data +``` + +### curl 示例 + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "List 3 colors"}], + "response_format": {"type": "json_object"} + }' +``` + +## curl 示例 + +### Chat + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "max_tokens": 100 + }' +``` + +### Streaming + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "default", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": true + }' +``` + +## Streaming 配置 + +使用 `--stream-interval` 控制 streaming 行为: + +| 值 | 行为 | +|----|------| +| `1`(默认) | 每个 token 立即发送 | +| `2-5` | 积攒若干 token 后再发送 | +| `10+` | 最大吞吐量,输出分块较大 | + +```bash +# Smooth streaming +vllm-mlx serve model --continuous-batching --stream-interval 1 + 
+# Batched streaming (better for high-latency networks) +vllm-mlx serve model --continuous-batching --stream-interval 5 +``` + +## Open WebUI 集成 + +```bash +# 1. Start vllm-mlx server +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --port 8000 + +# 2. Start Open WebUI +docker run -d -p 3000:8080 \ + -e OPENAI_API_BASE_URL=http://host.docker.internal:8000/v1 \ + -e OPENAI_API_KEY=not-needed \ + --name open-webui \ + ghcr.io/open-webui/open-webui:main + +# 3. Open http://localhost:3000 +``` + +## 生产部署 + +### 使用 systemd + +创建 `/etc/systemd/system/vllm-mlx.service`: + +```ini +[Unit] +Description=vLLM-MLX Server +After=network.target + +[Service] +Type=simple +ExecStart=/usr/local/bin/vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching --use-paged-cache --port 8000 +Restart=always + +[Install] +WantedBy=multi-user.target +``` + +```bash +sudo systemctl enable vllm-mlx +sudo systemctl start vllm-mlx +``` + +### 推荐配置 + +适用于 50 个以上并发用户的生产环境: + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --timeout 120 \ + --port 8000 +``` diff --git a/docs/zh/guides/tool-calling.md b/docs/zh/guides/tool-calling.md new file mode 100644 index 000000000..0e6b69422 --- /dev/null +++ b/docs/zh/guides/tool-calling.md @@ -0,0 +1,245 @@ +# Tool Calling + +vllm-mlx 支持与 OpenAI 兼容的 tool calling(function calling),并为多种主流模型系列提供自动解析。 + +## 快速开始 + +启动服务器时添加 `--enable-auto-tool-choice` 标志即可启用 tool calling: + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral +``` + +然后通过标准 OpenAI API 使用工具: + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed") + +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=[{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"] + } + } + }] +) + +# Check for tool calls +if response.choices[0].message.tool_calls: + for tc in response.choices[0].message.tool_calls: + print(f"Function: {tc.function.name}") + print(f"Arguments: {tc.function.arguments}") +``` + +## 支持的 tool parser + +使用 `--tool-call-parser` 为您的模型系列选择对应的 tool parser: + +| Parser | 别名 | 模型 | 格式 | +|--------|------|------|------| +| `auto` | | 任意模型 | 自动检测格式(依次尝试所有 parser) | +| `mistral` | | Mistral、Devstral | `[TOOL_CALLS]` JSON 数组 | +| `qwen` | `qwen3` | Qwen、Qwen3 | `` XML 或 `[Calling tool:]` | +| `llama` | `llama3`、`llama4` | Llama 3.x、4.x | `` 标签 | +| `hermes` | `nous` | Hermes、NousResearch | `` XML 包裹的 JSON | +| `deepseek` | `deepseek_v3`、`deepseek_r1` | DeepSeek V3、R1 | Unicode 分隔符 | +| `kimi` | `kimi_k2`、`moonshot` | Kimi K2、Moonshot | `<\|tool_call_begin\|>` 标记 | +| `granite` | `granite3` | IBM Granite 3.x、4.x | `<\|tool_call\|>` 或 `` | +| `nemotron` | `nemotron3` | NVIDIA Nemotron | `` | +| `xlam` | | Salesforce xLAM | 含 `tool_calls` 数组的 JSON | +| `functionary` | `meetkai` | MeetKai Functionary | 多个 function 块 | +| `glm47` | `glm4` | GLM-4.7、GLM-4.7-Flash | `` 配合 ``/`` XML | + +## 模型示例 + +### Mistral / Devstral + +```bash +# Devstral Small(针对编程和 tool use 优化) +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral + +# Mistral 
Instruct +vllm-mlx serve mlx-community/Mistral-7B-Instruct-v0.3-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral +``` + +### Qwen + +```bash +# Qwen3 +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser qwen +``` + +### Llama + +```bash +# Llama 3.2 +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser llama +``` + +### DeepSeek + +```bash +# DeepSeek V3 +vllm-mlx serve mlx-community/DeepSeek-V3-0324-4bit \ + --enable-auto-tool-choice --tool-call-parser deepseek +``` + +### IBM Granite + +```bash +# Granite 4.0 +vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \ + --enable-auto-tool-choice --tool-call-parser granite +``` + +### NVIDIA Nemotron + +```bash +# Nemotron 3 Nano +vllm-mlx serve mlx-community/NVIDIA-Nemotron-3-Nano-30B-A3B-MLX-6Bit \ + --enable-auto-tool-choice --tool-call-parser nemotron +``` + +### GLM-4.7 + +```bash +# GLM-4.7 Flash +vllm-mlx serve lmstudio-community/GLM-4.7-Flash-MLX-8bit \ + --enable-auto-tool-choice --tool-call-parser glm47 +``` + +### Kimi K2 + +```bash +# Kimi K2 +vllm-mlx serve mlx-community/Kimi-K2-Instruct-4bit \ + --enable-auto-tool-choice --tool-call-parser kimi +``` + +### Salesforce xLAM + +```bash +# xLAM +vllm-mlx serve mlx-community/xLAM-2-fc-r-4bit \ + --enable-auto-tool-choice --tool-call-parser xlam +``` + +## Auto Parser + +如果不确定使用哪个 tool parser,`auto` parser 会尝试自动检测格式: + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --enable-auto-tool-choice --tool-call-parser auto +``` + +auto parser 按以下顺序依次尝试各种格式: + +1. Mistral(`[TOOL_CALLS]`) +2. Qwen 括号格式(`[Calling tool:]`) +3. Nemotron(``) +4. Qwen/Hermes XML(`{...}`) +5. Llama(`{...}`) +6. 原始 JSON + +## Streaming Tool Calls + +Tool calling 支持 streaming。模型生成完毕后发送 tool call 信息: + +```python +stream = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's 25 * 17?"}], + tools=[{ + "type": "function", + "function": { + "name": "calculator", + "description": "Calculate math expressions", + "parameters": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + } + }], + stream=True +) + +for chunk in stream: + if chunk.choices[0].delta.tool_calls: + for tc in chunk.choices[0].delta.tool_calls: + print(f"Tool call: {tc.function.name}({tc.function.arguments})") +``` + +## 处理工具返回结果 + +收到 tool call 后,执行对应函数并将结果返回给模型: + +```python +import json + +# 第一次请求,模型决定调用工具 +response = client.chat.completions.create( + model="default", + messages=[{"role": "user", "content": "What's the weather in Tokyo?"}], + tools=[weather_tool] +) + +# 获取 tool call +tool_call = response.choices[0].message.tool_calls[0] +tool_call_id = tool_call.id +function_name = tool_call.function.name +arguments = json.loads(tool_call.function.arguments) + +# 执行函数(由您自行实现) +result = get_weather(**arguments) # {"temperature": 22, "condition": "sunny"} + +# 将结果返回给模型 +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "user", "content": "What's the weather in Tokyo?"}, + {"role": "assistant", "tool_calls": [tool_call]}, + {"role": "tool", "tool_call_id": tool_call_id, "content": json.dumps(result)} + ], + tools=[weather_tool] +) + +print(response.choices[0].message.content) +# "The weather in Tokyo is sunny with a temperature of 22C." 
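+
+# Hedged sketch, not from the guide above: when several tools are registered, a
+# small dispatch table keeps the execution step generic. `get_time` is a
+# hypothetical second tool; `get_weather` is the local function already used above.
+available_tools = {"get_weather": get_weather, "get_time": get_time}
+
+def run_tool_calls(message):
+    """Execute every tool call in an assistant message and build `tool` messages."""
+    tool_messages = []
+    for tc in message.tool_calls or []:
+        fn = available_tools[tc.function.name]
+        result = fn(**json.loads(tc.function.arguments))
+        tool_messages.append(
+            {"role": "tool", "tool_call_id": tc.id, "content": json.dumps(result)}
+        )
+    return tool_messages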
+``` + +## Think 标签处理 + +会产生 `...` reasoning 标签的模型(如 DeepSeek-R1、Qwen3、GLM-4.7)均可自动处理。tool parser 会在提取 tool call 前剥离 thinking 内容,因此 reasoning 标签不会干扰 tool call 解析。 + +即使 `` 是通过提示词注入的(即仅有闭合标签 `` 的隐式 think 标签),也同样适用。 + +## CLI 参数参考 + +| 选项 | 说明 | +|------|------| +| `--enable-auto-tool-choice` | 启用自动 tool calling | +| `--tool-call-parser` | 选择 tool parser(见上表) | + +完整选项请参阅 [CLI Reference](../reference/cli.md)。 diff --git a/docs/zh/guides/warm-prompts.md b/docs/zh/guides/warm-prompts.md new file mode 100644 index 000000000..6f5196a70 --- /dev/null +++ b/docs/zh/guides/warm-prompts.md @@ -0,0 +1,142 @@ +# Warm Prompts + +在服务器启动时预先填充 prefix cache,使 agent 发送的**第一个**请求命中已预热的缓存,而无需为其数千字节的系统提示支付完整的 prefill 开销。 + +## 适用场景 + +Agent 工作负载,如代理编码助手或推理助手的代理、MCP 服务器、多 agent 编排器,始终会发送相同的系统提示。在当前实现中,冷启动服务器收到的第一个请求需要为该系统提示支付完整的 prefill 代价。对于数十亿参数的模型,这意味着数秒的 TTFT,而此时用户正在等待其新 agent 首次响应。 + +如果您在部署时已知道各 agent 的系统提示,可将其写入一个 JSON 文件并通过 `--warm-prompts` 指向它。服务器会在启动时对每条提示执行一次 `max_tokens=1` 的聊天补全,KV cache 状态随即落入 prefix cache,后续真实请求即可通过严格前缀匹配命中缓存。 + +此功能需要 `--continuous-batching`(prefix cache 依赖该模式)。 + +## 快速示例 + +```bash +# 一次性写入您关心的 agent +cat > ~/.config/vllm-mlx/agents.json <<'JSON' +[ + [{"role": "system", "content": "You are a code assistant..."}] +] +JSON + +# 将服务器指向该文件 +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json +``` + +启动时您将看到: + +``` +[lifespan] Warm-up done (strict-prefix): 1 completed, 0 skipped, + 1431 prompt tokens in 0.2s +``` + +第一个共享已预热系统提示的真实请求将命中缓存,其 `tokens_saved` 接近预热提示的长度。 + +## 文件格式 + +顶层为一个 JSON 列表,每个条目本身也是一个聊天消息列表,结构与 `/v1/chat/completions` 中的 `messages` 字段相同。 + +```json +[ + [ + {"role": "system", "content": "You are a code assistant..."} + ], + [ + {"role": "system", "content": "You are a senior code reviewer..."} + ], + [ + {"role": "system", "content": "You are a planner..."}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello, what are we planning?"} + ] +] +``` + +单条系统提示是最常见的用法。多轮历史也受支持,适用于需要预热特定对话开头的场景(少样本示例、持续运行的助手角色等)。 + +## 规模建议 + +预热提示通过 `asyncio.gather` **并发**处理,因此 N 条条目会在启动时触发 N 个并发 prefill,每个 prefill 会为其提示长度分配 KV cache。 + +**建议条目数为 1 至 3 条**,足以覆盖典型 agent 部署的热路径(每个角色一条)。在内存紧张的模型上,过大的 warm-prompts 文件可能在启动时耗尽可用空间。 + +如需预热数十个角色,请提交一个 issue 并说明您的工作负载,我们可以添加 `--warm-prompts-concurrency=N` 上限参数。 + +## 基准测试 + +**测试环境:** M4 Max,128 GB 统一内存。每次测量使用两个独立服务器(冷启动与预热),隔离冷启动。`long` 提示集(约 2500 个用户 token)前置约 1700 token 的系统提示以匹配预热提示。`max_tokens=128`。bench-serve 使用 `--skip-preflight-token-count`,避免 count_prompt_tokens 预检污染缓存。 + +| 模型 | 并发 | 冷启动 TTFT | 预热 TTFT | 加速比 | +|------|-----:|----------:|----------:|------:| +| Qwen3-0.6B-8bit | 1 | 563 ms | 419 ms | 1.34x | +| Qwen3-0.6B-8bit | 4 | 1 723 ms | 1 282 ms | 1.34x | +| Qwen3-0.6B-8bit | 8 | 3 708 ms | 2 661 ms | 1.39x | +| Llama-3.2-3B-Instruct-4bit | 1 | 1 754 ms | 1 060 ms | 1.65x | +| Llama-3.2-3B-Instruct-4bit | 4 | 5 926 ms | 3 945 ms | 1.50x | +| Llama-3.2-3B-Instruct-4bit | 8 | 15 161 ms | 9 820 ms | 1.54x | +| Qwen3-4B-4bit | 1 | 4 937 ms | 2 191 ms | 2.25x | +| Qwen3-4B-4bit | 4 | 12 535 ms | 9 623 ms | 1.30x | +| Qwen3-4B-4bit | 8 | 38 148 ms | 23 878 ms | 1.60x | +| Qwen3.6-35B-A3B-4bit (MoE/hybrid) | 1 | 2 400 ms | 1 603 ms | 1.50x | +| Qwen3.6-35B-A3B-4bit | 4 | 8 735 ms | 6 054 ms | 1.44x | +| Qwen3.6-35B-A3B-4bit | 8 | 22 419 ms | 14 409 ms | 1.56x | + +全部 12 项配置均有提升。当提示占总长度比例最高时(并发=1,长系统提示),TTFT 节省最为显著,在并发负载下仍有实质性收益。 + +**生成 tok/s** 对于稠密模型基本持平(误差在 ±5% 以内)。Qwen3.6-35B-A3B(MoE)在并发数大于等于 4 时出现 20 至 35% 的解码速度下降,原因似乎是 MoE 
路由与批量调度之间的交互。对于 agent 工作负载,TTFT 节省仍主导端到端延迟,但若您的工作流在高并发下以解码为瓶颈,请注意这一点。 + +## 工作原理 + +朴素的预热方式,即用占位用户消息渲染聊天模板并缓存 token,对于混合 SSM+attention 模型(Qwen3.5-MoE、Qwen3.6-MoE)不适用。这类模型的缓存层包含无法裁剪的 SSM 状态,因此 `memory_cache.py` 禁用了 LCP 匹配。占位用户内容与真实用户内容不同,基于 token 的缓存条目不再是任何真实请求的严格前缀。 + +本预热器会用两个不同的用户内容(`"__PROBE_A__"` 和 `"__PROBE_B__"`)**两次**渲染聊天模板,找到两个字符串开始发散的字符位置,并在该边界处截断第一次渲染的结果。这段截断后的字符串,即用户内容被插入之前的全部内容,是发送给引擎的内容。 + +由于引擎的真实请求路径同样使用 `tokenize=False` 渲染模板,再由分词器对结果进行编码,因此预热生成的 token 保证是任何具有匹配系统提示且聊天历史为空的真实请求的严格前缀。严格前缀匹配适用于所有缓存层类型,包括禁用 LCP 的混合路径。 + +## 管理操作 + +### 清除内存中的 prefix cache + +```bash +curl -X DELETE http://localhost:8000/v1/cache/prefix +``` + +若服务器以 `--warm-prompts` 启动,清除后会在后台重新执行预热。响应会立即返回,不等待重新预热完成。 + +响应: + +```json +{"status": "cleared", "rewarm_scheduled": true} +``` + +### 查看缓存状态 + +```bash +curl http://localhost:8000/v1/status | jq '.cache' +``` + +使用 warm-prompts 启动后,在第一个用户请求到来之前,您将看到 `entry_count > 0`。 + +## 针对您自己的场景进行基准测试 + +如需测量对您的模型和提示的实际影响,请使用 `bench-serve`: + +```bash +# 冷启动:不使用 warm-prompts +vllm-mlx serve MODEL --continuous-batching & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag cold \ + --output cold.csv --format csv + +# 预热:相同服务器配置 + --warm-prompts +vllm-mlx serve MODEL --continuous-batching \ + --warm-prompts ~/.config/vllm-mlx/agents.json & +vllm-mlx bench-serve --prompts long --concurrency 1,4 \ + --system-prompt-file my-system.txt --tag warm \ + --output warm.csv --format csv +``` + +设置 `--system-prompt-file` 时会自动启用 `--skip-preflight-token-count`,防止 `count_prompt_tokens` 预检污染缓存。比较 `cold.csv` 与 `warm.csv` 即可评估您工作负载的实际效果。 diff --git a/docs/zh/index.md b/docs/zh/index.md new file mode 100644 index 000000000..dcb339eb3 --- /dev/null +++ b/docs/zh/index.md @@ -0,0 +1,66 @@ +# vLLM-MLX 文档 + +**Apple Silicon 的 MLX 推理后端** - 在 Mac 上对文本、图像、视频和音频进行 GPU 加速推理 + +## 什么是 vLLM-MLX? 
+ +vllm-mlx 通过集成以下组件,为 vLLM 带来原生 Apple Silicon GPU 加速: + +- **[MLX](https://github.com/ml-explore/mlx)**:Apple 的机器学习框架,具有统一内存和 Metal 内核 +- **[mlx-lm](https://github.com/ml-explore/mlx-lm)**:经过优化的 LLM 推理,支持 KV cache 和量化 +- **[mlx-vlm](https://github.com/Blaizzy/mlx-vlm)**:用于多模态推理的视觉语言模型 VLM +- **[mlx-audio](https://github.com/Blaizzy/mlx-audio)**:基于原生语音的 TTS 和 STT +- **[mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings)**:用于语义搜索和 RAG 的文本 embeddings + +## 主要特性 + +- **多模态** - 在同一平台上处理文本、图像、视频和音频 +- **原生 GPU 加速**,支持 Apple Silicon(M1、M2、M3、M4) +- **原生 TTS 语音** - 支持西班牙语、法语、中文、日语及其他 5 种语言 +- **OpenAI API 兼容** - 可直接替换 OpenAI 客户端 +- **Embeddings** - 兼容 OpenAI 的 `/v1/embeddings` 端点 +- **MCP Tool Calling** - 通过 Model Context Protocol 集成外部工具 +- **Paged KV Cache** - 支持前缀共享的高效内存缓存 +- **Continuous Batching** - 为多并发用户提供高吞吐量 + +## 快速链接 + +### 入门指南 +- [安装](getting-started/installation.md) +- [快速开始](getting-started/quickstart.md) + +### 用户指南 +- [兼容 OpenAI 的服务器](guides/server.md) +- [Python API](guides/python-api.md) +- [多模态(图像与视频)](guides/multimodal.md) +- [音频(STT/TTS)](guides/audio.md) +- [Embeddings](guides/embeddings.md) +- [Reasoning 模型](guides/reasoning.md) +- [Tool Calling](guides/tool-calling.md) +- [MCP 与 Tool Calling](guides/mcp-tools.md) +- [Continuous Batching](guides/continuous-batching.md) + +### 参考文档 +- [CLI 命令](reference/cli.md) +- [支持的模型](reference/models.md) +- [配置说明](reference/configuration.md) + +### 基准测试 +- [LLM 基准测试](benchmarks/llm.md) +- [图像基准测试](benchmarks/image.md) +- [视频基准测试](benchmarks/video.md) +- [音频基准测试](benchmarks/audio.md) + +### 开发者文档 +- [架构设计](../development/architecture.md) +- [贡献指南](../development/contributing.md) + +## 环境要求 + +- 搭载 Apple Silicon 的 macOS(M1/M2/M3/M4) +- Python 3.10 及以上 +- 推荐 8GB 及以上内存 + +## 许可证 + +Apache 2.0,详情请参阅 [LICENSE](../../LICENSE)。 diff --git a/docs/zh/reference/cli.md b/docs/zh/reference/cli.md new file mode 100644 index 000000000..0117c506d --- /dev/null +++ b/docs/zh/reference/cli.md @@ -0,0 +1,210 @@ +# CLI 参考 + +## 命令概览 + +| 命令 | 说明 | +|---------|-------------| +| `vllm-mlx serve` | 启动兼容 OpenAI 的服务器 | +| `vllm-mlx-bench` | 运行性能基准测试 | +| `vllm-mlx-chat` | 启动 Gradio 对话界面 | + +## `vllm-mlx serve` + +启动兼容 OpenAI 的 API 服务器。 + +### 用法 + +```bash +vllm-mlx serve [options] +``` + +### 选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--served-model-name` | 通过 OpenAI API 暴露的自定义模型名称。未设置时使用模型路径作为名称。 | None | +| `--port` | 服务器端口 | 8000 | +| `--host` | 服务器主机 | 127.0.0.1 | +| `--api-key` | 用于身份验证的 API 密钥 | None | +| `--rate-limit` | 每个客户端每分钟的请求数(0 表示禁用) | 0 | +| `--timeout` | 请求超时时间,单位为秒 | 300 | +| `--enable-metrics` | 在 `/metrics` 上暴露 Prometheus 指标 | False | +| `--continuous-batching` | 为多用户启用 continuous batching | False | +| `--cache-memory-mb` | 缓存内存上限,单位为 MB | Auto | +| `--cache-memory-percent` | 用于缓存的内存占比 | 0.20 | +| `--no-memory-aware-cache` | 使用旧版按条目数计数的缓存 | False | +| `--use-paged-cache` | 启用 paged KV cache | False | +| `--max-tokens` | 默认最大 token 数 | 32768 | +| `--max-request-tokens` | API 客户端可传入的最大 `max_tokens` | 32768 | +| `--stream-interval` | 每个 streaming 分块包含的 token 数 | 1 | +| `--mcp-config` | MCP 配置文件路径 | None | +| `--paged-cache-block-size` | 每个缓存块包含的 token 数 | 64 | +| `--max-cache-blocks` | 最大缓存块数 | 1000 | +| `--max-num-seqs` | 最大并发序列数 | 256 | +| `--default-temperature` | 请求未指定时的默认 temperature | None | +| `--default-top-p` | 请求未指定时的默认 top_p | None | +| `--max-audio-upload-mb` | `/v1/audio/transcriptions` 接受的最大音频上传大小 | 25 | +| `--max-tts-input-chars` | `/v1/audio/speech` 接受的最大文本长度 | 4096 | +| `--reasoning-parser` | 
reasoning 模型的解析器(`qwen3`、`deepseek_r1`) | None | +| `--embedding-model` | 启动时预加载 embeddings 模型 | None | +| `--enable-auto-tool-choice` | 启用自动 tool calling | False | +| `--tool-call-parser` | tool call 解析器(`auto`、`mistral`、`qwen`、`llama`、`hermes`、`deepseek`、`kimi`、`granite`、`nemotron`、`xlam`、`functionary`、`glm47`) | None | + +### 示例 + +```bash +# Simple mode (single user, max throughput) +# Model path is used as the model name in the OpenAI API (e.g. model="mlx-community/Llama-3.2-3B-Instruct-4bit") +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit + +Model will show up as 'mlx-community/Llama-3.2-3B-Instruct-4bit' in the `/v1/models` API endpoint. View with `curl http://localhost:8000/v1/models` or similar. + +# With a custom API model name (model is accessed as "my-model" via the OpenAI API) +# --served-model-name sets the name clients must use when calling the API (e.g. model="my-model") +vllm-mlx serve --served-model-name my-model mlx-community/Llama-3.2-3B-Instruct-4bit +# Note: Model will show up as 'my-model' in the `/v1/models` API endpoint. + +# Continuous batching (multiple users) +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --continuous-batching + +# With memory limit for large models +vllm-mlx serve mlx-community/GLM-4.7-Flash-4bit \ + --continuous-batching \ + --cache-memory-mb 2048 + +# Production with paged cache +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --port 8000 + +# With MCP tools +vllm-mlx serve mlx-community/Qwen3-4B-4bit --mcp-config mcp.json + +# Multimodal model +vllm-mlx serve mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Reasoning model (separates thinking from answer) +vllm-mlx serve mlx-community/Qwen3-8B-4bit --reasoning-parser qwen3 + +# DeepSeek reasoning model +vllm-mlx serve mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit --reasoning-parser deepseek_r1 + +# Tool calling with Mistral/Devstral +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice --tool-call-parser mistral + +# Tool calling with Granite +vllm-mlx serve mlx-community/granite-4.0-tiny-preview-4bit \ + --enable-auto-tool-choice --tool-call-parser granite + +# With API key authentication +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --api-key your-secret-key + +# Expose Prometheus metrics +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit --enable-metrics + +# Production setup with security options +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --api-key your-secret-key \ + --rate-limit 60 \ + --timeout 120 \ + --continuous-batching +``` + +### 安全 + +设置 `--api-key` 后,所有 API 请求都需要携带 `Authorization: Bearer ` 请求头: + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="your-secret-key" # Must match --api-key +) +``` + +或使用 curl: + +```bash +curl http://localhost:8000/v1/models \ + -H "Authorization: Bearer your-secret-key" +``` + +## `vllm-mlx-bench` + +运行性能基准测试。 + +### 用法 + +```bash +vllm-mlx-bench --model [options] +``` + +### 选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--model` | 模型名称 | 必填 | +| `--prompts` | 提示词数量 | 5 | +| `--max-tokens` | 每条提示词的最大 token 数 | 256 | +| `--quick` | 快速基准测试模式 | False | +| `--video` | 运行视频基准测试 | False | +| `--video-url` | 自定义视频 URL | None | +| `--video-path` | 自定义视频路径 | None | + +### 示例 + +```bash +# LLM benchmark +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit + +# Quick benchmark +vllm-mlx-bench --model mlx-community/Llama-3.2-1B-Instruct-4bit --quick + +# 
Image benchmark (auto-detected for VLM models) +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit + +# Video benchmark +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit --video + +# Custom video +vllm-mlx-bench --model mlx-community/Qwen3-VL-8B-Instruct-4bit \ + --video --video-url https://example.com/video.mp4 +``` + +## `vllm-mlx-chat` + +启动 Gradio 对话界面。 + +### 用法 + +```bash +vllm-mlx-chat --served-model-name [options] +``` + +### 选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--model` | 模型名称 | 必填 | +| `--port` | Gradio 端口 | 7860 | +| `--text-only` | 禁用多模态功能 | False | + +### 示例 + +```bash +# Multimodal chat (text + images + video) +vllm-mlx-chat --served-model-name mlx-community/Qwen3-VL-4B-Instruct-3bit + +# Text-only chat +vllm-mlx-chat --served-model-name mlx-community/Llama-3.2-3B-Instruct-4bit --text-only +``` + +## 环境变量 + +| 变量 | 说明 | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | 测试使用的模型 | +| `HF_TOKEN` | HuggingFace token | diff --git a/docs/zh/reference/configuration.md b/docs/zh/reference/configuration.md new file mode 100644 index 000000000..a2a99b480 --- /dev/null +++ b/docs/zh/reference/configuration.md @@ -0,0 +1,189 @@ +# 配置参考 + +## 服务器配置 + +### 基本选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--host` | 服务器主机地址 | `127.0.0.1` | +| `--port` | 服务器端口 | `8000` | +| `--max-tokens` | 默认最大 token 数 | `32768` | +| `--max-request-tokens` | API 客户端可传入的最大 `max_tokens` 值 | `32768` | +| `--default-temperature` | 请求未指定时使用的默认 temperature | None | +| `--default-top-p` | 请求未指定时使用的默认 top_p | None | + +### 安全选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--api-key` | 用于身份验证的 API key | None | +| `--rate-limit` | 每个客户端每分钟的请求数(0 表示禁用) | `0` | +| `--timeout` | 请求超时时间(秒) | `300` | +| `--enable-metrics` | 在 `/metrics` 上暴露 Prometheus 指标 | `false` | +| `--max-audio-upload-mb` | `/v1/audio/transcriptions` 接受的最大音频上传大小 | `25` | +| `--max-tts-input-chars` | `/v1/audio/speech` 接受的最大文本长度 | `4096` | + +### 批处理选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--continuous-batching` | 启用 continuous batching | `false` | +| `--stream-interval` | 每个 streaming 分块包含的 token 数 | `1` | +| `--max-num-seqs` | 最大并发序列数 | `256` | + +### 缓存选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--cache-memory-mb` | 缓存内存上限(MB) | 自动 | +| `--cache-memory-percent` | 分配给缓存的内存比例 | `0.20` | +| `--no-memory-aware-cache` | 使用旧版基于条目数量的缓存 | `false` | +| `--use-paged-cache` | 启用 paged KV cache | `false` | +| `--paged-cache-block-size` | 每个块包含的 token 数 | `64` | +| `--max-cache-blocks` | 最大块数量 | `1000` | + +### 工具调用选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--enable-auto-tool-choice` | 启用自动工具调用 | `false` | +| `--tool-call-parser` | 工具调用解析器(参见 [工具调用](../guides/tool-calling.md)) | None | + +### 推理选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--reasoning-parser` | 推理模型解析器(`qwen3`、`deepseek_r1`) | None | + +### 嵌入选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--embedding-model` | 启动时预加载嵌入模型 | None | + +### MCP 选项 + +| 选项 | 说明 | 默认值 | +|--------|-------------|---------| +| `--mcp-config` | MCP 配置文件路径 | None | + +## MCP 配置 + +创建 `mcp.json`: + +```json +{ + "mcpServers": { + "server-name": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-name", "arg1"], + "env": { + "ENV_VAR": "value" + } + } + } +} +``` + +### MCP 服务器字段 + +| 字段 | 说明 | 是否必填 | +|-------|-------------|----------| +| `command` | 可执行命令 | 是 | +| `args` | 命令参数 | 是 | +| `env` | 环境变量 | 否 | + +## API 请求选项 
+ +### 聊天补全 + +| 参数 | 说明 | 默认值 | +|-----------|-------------|---------| +| `model` | 模型名称 | 必填 | +| `messages` | 聊天消息 | 必填 | +| `max_tokens` | 最大生成 token 数 | 256 | +| `temperature` | 采样 temperature | 模型默认值 | +| `top_p` | Nucleus sampling | 模型默认值 | +| `stream` | 启用 streaming | `true` | +| `stop` | 停止序列 | None | +| `tools` | 工具定义 | None | +| `response_format` | 输出格式(`json_object`、`json_schema`) | None | + +### 多模态选项 + +| 参数 | 说明 | 默认值 | +|-----------|-------------|---------| +| `video_fps` | 每秒帧数 | 2.0 | +| `video_max_frames` | 最大帧数 | 32 | + +## 环境变量 + +| 变量 | 说明 | +|----------|-------------| +| `VLLM_MLX_TEST_MODEL` | 测试使用的默认模型 | +| `HF_TOKEN` | HuggingFace 身份验证 token | +| `OPENAI_API_KEY` | 设为任意值以兼容 SDK | + +## 配置示例 + +### 开发环境(单用户) + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### 生产环境(多用户) + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --use-paged-cache \ + --api-key your-secret-key \ + --rate-limit 60 \ + --port 8000 +``` + +### 使用工具调用 + +```bash +vllm-mlx serve mlx-community/Devstral-Small-2507-4bit \ + --enable-auto-tool-choice \ + --tool-call-parser mistral \ + --continuous-batching +``` + +### 使用 MCP 工具 + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --mcp-config mcp.json \ + --enable-auto-tool-choice \ + --tool-call-parser qwen \ + --continuous-batching +``` + +### 推理模型 + +```bash +vllm-mlx serve mlx-community/Qwen3-8B-4bit \ + --reasoning-parser qwen3 \ + --continuous-batching +``` + +### 使用嵌入 + +```bash +vllm-mlx serve mlx-community/Qwen3-4B-4bit \ + --embedding-model mlx-community/multilingual-e5-small-mlx \ + --continuous-batching +``` + +### 高吞吐量 + +```bash +vllm-mlx serve mlx-community/Qwen3-0.6B-8bit \ + --continuous-batching \ + --stream-interval 5 \ + --max-num-seqs 256 +``` diff --git a/docs/zh/reference/models.md b/docs/zh/reference/models.md new file mode 100644 index 000000000..afde19a4b --- /dev/null +++ b/docs/zh/reference/models.md @@ -0,0 +1,99 @@ +# 支持的模型 + +所有来自 [mlx-community on HuggingFace](https://huggingface.co/mlx-community/models) 的量化模型均兼容。 + +在以下地址浏览数千个预优化模型:**https://huggingface.co/mlx-community/models** + +## 语言模型(通过 mlx-lm) + +| 模型系列 | 规格 | 量化方式 | +|--------------|-------|--------------| +| Llama 3.x, 4.x | 1B, 3B, 8B, 70B | 4-bit | +| Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit | +| Qwen2/Qwen3 | 0.5B 至 72B | 多种 | +| DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit | +| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit | +| GLM-4.7 | Flash, Base | 4-bit, 8-bit | +| Kimi K2 | 多种 | 4-bit | +| Phi-3 | 3.8B, 14B | 4-bit | +| Granite 3.x, 4.x | 多种 | 4-bit | +| Nemotron | 3 Nano 30B | 6-bit | + +### 推荐模型 + +| 使用场景 | 模型 | 内存 | +|----------|-------|--------| +| 快速/轻量 | `mlx-community/Qwen3-0.6B-8bit` | ~0.7 GB | +| 均衡 | `mlx-community/Llama-3.2-3B-Instruct-4bit` | ~1.8 GB | +| 高质量 | `mlx-community/Llama-3.1-8B-Instruct-4bit` | ~4.5 GB | +| 大型 | `mlx-community/Qwen3-30B-A3B-4bit` | ~16 GB | + +## 多模态模型(通过 mlx-vlm) + +| 模型系列 | 示例模型 | +|--------------|----------------| +| **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` | +| **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` | +| **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` | +| **Gemma 4** | `gemma-4-e2b-it-mxfp4`(视觉 + 音频) | +| **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` | +| **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` | +| **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` | +| **Phi-3 
Vision** | `Phi-3-vision-128k-instruct-4bit` | +| **DeepSeek-VL** | `deepseek-vl-7b-chat-4bit`, `deepseek-vl2-small-4bit` | + +### 推荐 VLM 模型 + +| 使用场景 | 模型 | 内存 | +|----------|-------|--------| +| 快速/轻量 | `mlx-community/Qwen3-VL-4B-Instruct-3bit` | ~3 GB | +| 均衡 | `mlx-community/Qwen3-VL-8B-Instruct-4bit` | ~6 GB | +| 高质量 | `mlx-community/Qwen3-VL-30B-A3B-Instruct-6bit` | ~20 GB | + +## Embedding 模型(通过 mlx-embeddings) + +| 模型系列 | 示例模型 | +|--------------|----------------| +| **BERT** | `mlx-community/bert-base-uncased-mlx` | +| **XLM-RoBERTa** | `mlx-community/multilingual-e5-small-mlx`, `mlx-community/multilingual-e5-large-mlx` | +| **ModernBERT** | `mlx-community/ModernBERT-base-mlx` | + +## 音频模型(通过 mlx-audio) + +| 类型 | 模型系列 | 示例模型 | +|------|--------------|----------------| +| **STT** | Whisper | `mlx-community/whisper-large-v3-turbo` | +| **STT** | Parakeet | `mlx-community/parakeet-tdt-0.6b-v2` | +| **TTS** | Kokoro | `prince-canuma/Kokoro-82M` | +| **TTS** | Chatterbox | `chatterbox/chatterbox-tts-0.1` | + +## 模型自动检测 + +vllm-mlx 通过名称模式自动检测多模态模型: +- 包含 "VL"、"Vision"、"vision" +- 包含 "llava"、"idefics"、"paligemma" +- 包含 "pixtral"、"molmo"、"deepseek-vl" +- 包含 "MedGemma"、"Gemma-3"、"Gemma-4"(多模态变体) + +## 使用模型 + +### 从 HuggingFace 加载 + +```bash +vllm-mlx serve mlx-community/Llama-3.2-3B-Instruct-4bit +``` + +### 本地路径 + +```bash +vllm-mlx serve /path/to/local/model +``` + +## 查找模型 + +按以下关键词筛选 mlx-community 模型: +- **LLM**:`Llama`、`Qwen`、`Mistral`、`Phi`、`Gemma`、`DeepSeek`、`GLM`、`Kimi`、`Granite`、`Nemotron` +- **VLM**:`-VL-`、`llava`、`paligemma`、`pixtral`、`molmo`、`idefics`、`deepseek-vl`、`MedGemma` +- **Embedding**:`e5`、`bert`、`ModernBERT` +- **规格**:`1B`、`3B`、`7B`、`8B`、`70B` +- **量化方式**:`4bit`、`8bit`、`bf16` diff --git a/pyproject.toml b/pyproject.toml index 0e58749da..35bd6f69d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vllm-mlx" -version = "0.2.8" +version = "0.2.9" description = "vLLM-like inference for Apple Silicon - GPU-accelerated Text, Image, Video & Audio on Mac" readme = "README.md" license = {text = "Apache-2.0"} @@ -55,12 +55,14 @@ dependencies = [ "mcp>=1.0.0", # JSON Schema validation for structured output "jsonschema>=4.0.0", + # Constrained decoding (grammar-guided generation for response_format) + "lm-format-enforcer>=0.10.9", # pytz is required by gradio but not declared as a dependency # See: https://github.com/waybarrios/vllm-mlx/issues/23 "pytz>=2024.1", # Note: mlx-audio moved to optional [audio] deps due to mlx-lm version conflict # See: https://github.com/waybarrios/vllm-mlx/issues/19 - "mlx-embeddings>=0.0.5" + "mlx-embeddings>=0.0.5", ] [project.optional-dependencies] @@ -70,6 +72,7 @@ dev = [ "black>=23.0.0", "ruff>=0.1.0", "mypy>=1.0.0", + "pre-commit>=4.6.0", ] vllm = [ "vllm>=0.4.0", @@ -117,6 +120,9 @@ vllm-mlx-bench = "vllm_mlx.benchmark:main" where = ["."] include = ["vllm_mlx*"] +[tool.setuptools.package-data] +"vllm_mlx.bench_serve_prompts" = ["*.json"] + [tool.black] line-length = 88 target-version = ["py310", "py311", "py312", "py313"] @@ -125,8 +131,8 @@ target-version = ["py310", "py311", "py312", "py313"] line-length = 88 [tool.ruff.lint] -select = ["E", "F", "W", "I", "N", "UP", "B", "SIM"] -ignore = ["E501", "B905"] +select = ["E", "F", "W"] +ignore = ["B905", "E402", "E501", "E731", "F811", "F841"] [tool.mypy] python_version = "3.10" diff --git a/tests/test_api_models.py b/tests/test_api_models.py index fdb756aee..88d67cff8 100644 --- a/tests/test_api_models.py +++ 
b/tests/test_api_models.py @@ -680,6 +680,43 @@ def test_chat_completion_chunk_serializes_reasoning_content_only(self): assert delta["reasoning_content"] == "thinking" assert "reasoning" not in delta + def test_assistant_message_excludes_null_tool_calls(self): + msg = AssistantMessage(content="Hello!") + data = msg.model_dump() + assert "tool_calls" not in data + assert "reasoning_content" not in data + + def test_assistant_message_excludes_null_reasoning(self): + msg = AssistantMessage(content="Hello!") + data = msg.model_dump() + assert "reasoning_content" not in data + + def test_chunk_delta_excludes_null_fields(self): + delta = ChatCompletionChunkDelta(role="assistant") + data = delta.model_dump() + assert data == {"role": "assistant"} + assert "content" not in data + assert "tool_calls" not in data + assert "reasoning_content" not in data + + def test_chunk_delta_empty_serializes_to_empty_dict(self): + delta = ChatCompletionChunkDelta() + data = delta.model_dump() + assert data == {} + + def test_chunk_finish_reason_null_preserved(self): + chunk = ChatCompletionChunk( + model="test", + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(role="assistant"), + finish_reason=None, + ) + ], + ) + data = chunk.model_dump() + assert data["choices"][0]["finish_reason"] is None + def test_response_format_json_schema_alias(self): schema = ResponseFormatJsonSchema( name="test", diff --git a/tests/test_async_markers.py b/tests/test_async_markers.py new file mode 100644 index 000000000..85d8870b7 --- /dev/null +++ b/tests/test_async_markers.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Guard the test suite's async marker contract.""" + +from __future__ import annotations + +import ast +from pathlib import Path + + +def _is_pytest_mark_asyncio(node: ast.AST) -> bool: + """Return True when a decorator is exactly pytest.mark.asyncio.""" + return ( + isinstance(node, ast.Attribute) + and node.attr == "asyncio" + and isinstance(node.value, ast.Attribute) + and node.value.attr == "mark" + and isinstance(node.value.value, ast.Name) + and node.value.value.id == "pytest" + ) + + +def test_async_tests_use_anyio_markers(): + """The suite should not depend on pytest-asyncio anywhere.""" + offenders: list[str] = [] + tests_dir = Path(__file__).parent + + for path in sorted(tests_dir.glob("test_*.py")): + tree = ast.parse(path.read_text(), filename=str(path)) + for node in ast.walk(tree): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + if any( + _is_pytest_mark_asyncio(decorator) + for decorator in node.decorator_list + ): + offenders.append(f"{path.name}:{node.lineno}") + + assert ( + offenders == [] + ), "Found pytest.mark.asyncio after the suite migrated to AnyIO: " + ", ".join( + offenders + ) diff --git a/tests/test_audio_limits.py b/tests/test_audio_limits.py new file mode 100644 index 000000000..b1759f13d --- /dev/null +++ b/tests/test_audio_limits.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for audio endpoint resource limits.""" + +from pathlib import Path + +import pytest +from fastapi import HTTPException + +from vllm_mlx.audio_limits import ( + DEFAULT_MAX_AUDIO_UPLOAD_MB, + DEFAULT_MAX_TTS_INPUT_CHARS, + save_upload_with_limit, + validate_tts_input_length, +) + + +class FakeUpload: + def __init__(self, chunks: list[bytes], filename: str = "audio.wav"): + self._chunks = list(chunks) + self.filename = filename + + async def read(self, _size: int = -1) -> bytes: + if not self._chunks: + return b"" + return 
self._chunks.pop(0) + + +class TestAudioUploadLimits: + @pytest.mark.anyio + async def test_save_upload_with_limit_writes_file(self): + upload = FakeUpload([b"a" * 8, b"b" * 4]) + + path = await save_upload_with_limit(upload, max_bytes=32) + + try: + assert Path(path).read_bytes() == b"a" * 8 + b"b" * 4 + finally: + Path(path).unlink(missing_ok=True) + + @pytest.mark.anyio + async def test_save_upload_with_limit_rejects_oversize_and_cleans_up(self): + upload = FakeUpload([b"a" * 16, b"b" * 16, b"c"]) + + with pytest.raises(HTTPException) as exc_info: + await save_upload_with_limit(upload, max_bytes=32) + + assert exc_info.value.status_code == 413 + assert "Audio upload too large" in exc_info.value.detail + + +class TestTTSInputLimits: + def test_validate_tts_input_length_accepts_short_text(self): + validate_tts_input_length("hello", max_chars=16) + + def test_validate_tts_input_length_rejects_oversized_text(self): + with pytest.raises(HTTPException) as exc_info: + validate_tts_input_length("x" * 17, max_chars=16) + + assert exc_info.value.status_code == 413 + assert "TTS input too long" in exc_info.value.detail + + +class TestAudioLimitParsers: + def test_top_level_cli_exposes_audio_limit_flags(self): + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + [ + "serve", + "mlx-community/Llama-3.2-3B-Instruct-4bit", + "--max-audio-upload-mb", + "12", + "--max-tts-input-chars", + "2048", + ] + ) + + assert args.max_audio_upload_mb == 12 + assert args.max_tts_input_chars == 2048 + + def test_standalone_server_parser_defaults(self): + from vllm_mlx.server import create_parser + + parser = create_parser() + args = parser.parse_args( + ["--model", "mlx-community/Llama-3.2-3B-Instruct-4bit"] + ) + + assert args.max_audio_upload_mb == DEFAULT_MAX_AUDIO_UPLOAD_MB + assert args.max_tts_input_chars == DEFAULT_MAX_TTS_INPUT_CHARS diff --git a/tests/test_batched_engine.py b/tests/test_batched_engine.py index 73a7e8ffa..a78d83096 100644 --- a/tests/test_batched_engine.py +++ b/tests/test_batched_engine.py @@ -39,7 +39,7 @@ def _make_mock_request_output( mock.finish_reason = finish_reason return mock - @pytest.mark.asyncio + @pytest.mark.anyio async def test_tokens_field_is_populated(self): """tokens should contain the output token IDs from AsyncEngineCore.""" engine = self._make_engine() @@ -56,7 +56,7 @@ async def test_tokens_field_is_populated(self): assert result.tokens == token_ids - @pytest.mark.asyncio + @pytest.mark.anyio async def test_tokens_field_empty_when_no_tokens_generated(self): """tokens should be an empty list when output_token_ids is empty.""" engine = self._make_engine() @@ -70,7 +70,7 @@ async def test_tokens_field_empty_when_no_tokens_generated(self): assert result.tokens == [] - @pytest.mark.asyncio + @pytest.mark.anyio async def test_other_output_fields_still_populated(self): """Existing fields (text, prompt_tokens, etc.) 
must remain correct.""" engine = self._make_engine() @@ -92,3 +92,35 @@ async def test_other_output_fields_still_populated(self): assert result.prompt_tokens == 7 assert result.completion_tokens == 1 assert result.finish_reason == "stop" + + +class TestBatchedEngineCacheRestore: + def _make_mllm_engine(self): + from vllm_mlx.engine.batched import BatchedEngine + + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=True): + engine = BatchedEngine("test-mllm") + + engine._loaded = True + engine._is_mllm = True + return engine + + def test_load_cache_from_disk_bootstraps_mllm_batch_generator(self): + engine = self._make_mllm_engine() + + prefix_cache = MagicMock() + prefix_cache.load_from_disk.return_value = 2 + scheduler = MagicMock() + scheduler.batch_generator = None + + def ensure_batch_generator(): + scheduler.batch_generator = MagicMock(prefix_cache=prefix_cache) + + scheduler._ensure_batch_generator.side_effect = ensure_batch_generator + engine._mllm_scheduler = scheduler + + loaded = engine.load_cache_from_disk("/tmp/cache") + + assert loaded == 2 + scheduler._ensure_batch_generator.assert_called_once_with() + prefix_cache.load_from_disk.assert_called_once_with("/tmp/cache") diff --git a/tests/test_batching.py b/tests/test_batching.py index 6cb536aa5..84e118a6f 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -833,6 +833,84 @@ async def test_engine_context_manager(self, mock_model_and_tokenizer): assert not engine.engine.is_running() + async def test_stream_outputs_consumer_break_after_finished_does_not_abort(self): + """Breaking after a finished output is normal consumption, not orphaning.""" + from vllm_mlx.engine_core import EngineCore + from vllm_mlx.output_collector import RequestOutputCollector + from vllm_mlx.request import RequestOutput + + engine = EngineCore.__new__(EngineCore) + engine._output_collectors = {"req-1": RequestOutputCollector()} + engine._stream_states = {} + engine._finished_events = {} + engine.scheduler = MagicMock() + engine.scheduler.abort_request = MagicMock(return_value=True) + engine.scheduler.remove_finished_request = MagicMock() + + engine._output_collectors["req-1"].put( + RequestOutput( + request_id="req-1", + output_text="done", + finished=True, + finish_reason="stop", + ) + ) + + stream = EngineCore.stream_outputs(engine, "req-1") + output = await stream.__anext__() + assert output.finished is True + await stream.aclose() + + engine.scheduler.abort_request.assert_not_called() + engine.scheduler.remove_finished_request.assert_called_once_with("req-1") + + +class TestChunkedPrefillConfig: + """Regression tests for chunked prefill configuration (#178).""" + + def test_prompt_cache_save_installed_without_chunked_prefill(self): + """When chunked_prefill_tokens=0 but memory_aware_cache is active, + _install_prompt_cache_save should still patch _process_prompts.""" + from vllm_mlx.scheduler import _install_prompt_cache_save + + calls = [] + + class FakeBatchGen: + def _process_prompts(self, prompts): + class FakeBatch: + uids = [42] + num_tokens = [0] + + def extract_cache(self, idx): + return f"cache-{idx}" + + return FakeBatch() + + bg = FakeBatchGen() + orig_fn = bg._process_prompts + _install_prompt_cache_save(bg, lambda uid, cache: calls.append((uid, cache))) + + # _process_prompts was patched + assert bg._process_prompts is not orig_fn + bg._process_prompts([]) + assert calls == [(42, "cache-0")], f"Expected callback to fire, got {calls}" + + def test_chunked_prefill_zero_does_not_install_chunked_next(self): + 
"""chunked_prefill_tokens=0 must not install the chunked _next patch, + even when use_memory_aware_cache=True.""" + config = SchedulerConfig(chunked_prefill_tokens=0, use_memory_aware_cache=True) + need_chunked = config.chunked_prefill_tokens > 0 + assert not need_chunked, ( + "chunked_prefill_tokens=0 must not enable the chunked prefill " + "monkey-patch even when use_memory_aware_cache=True" + ) + + def test_chunked_prefill_positive_enables(self): + """Positive chunked_prefill_tokens should enable chunked prefill.""" + config = SchedulerConfig(chunked_prefill_tokens=4096) + need_chunked = config.chunked_prefill_tokens > 0 + assert need_chunked + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_batching_deterministic.py b/tests/test_batching_deterministic.py index 0e6072ce9..d171b2cef 100644 --- a/tests/test_batching_deterministic.py +++ b/tests/test_batching_deterministic.py @@ -7,9 +7,10 @@ """ import asyncio -import pytest import time +import pytest + # Model to use for tests - small model for fast testing TEST_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit" @@ -164,7 +165,7 @@ async def test_concurrent_different_prompts(self, model_and_tokenizer): # Run twice to verify determinism all_results = [] - for run in range(2): + for _run in range(2): async with AsyncEngineCore(model, tokenizer, config) as engine: await asyncio.sleep(0.05) @@ -246,6 +247,12 @@ async def get_output(rid): tokens = await asyncio.gather(*[get_output(r) for r in request_ids]) return sum(tokens) + # Warm-up: run once each to compile kernels and prime caches. + # Without this, the first timed run pays one-time compilation + # overhead, causing spurious failures on loaded machines. + await run_sequential() + await run_batched() + # Time sequential start = time.perf_counter() seq_tokens = await run_sequential() @@ -255,8 +262,6 @@ async def get_output(rid): start = time.perf_counter() batch_tokens = await run_batched() batch_time = time.perf_counter() - start - - # Batched should be faster (at least 1.5x) seq_throughput = seq_tokens / seq_time batch_throughput = batch_tokens / batch_time @@ -455,5 +460,52 @@ async def test_multiple_start_stop(self, model_and_tokenizer): break +class TestBatchGeneratorCleanup: + """Test that BatchGenerator is closed promptly on engine stop.""" + + @pytest.mark.anyio + async def test_batch_generator_closed_after_engine_stop(self, model_and_tokenizer): + """BatchGenerator must be closed before engine stop returns. + + If the generator is left alive for GC to collect on a worker thread, + its __del__ → close() → mx.synchronize() can SIGABRT because the + worker thread has no running asyncio event loop. 
+ """ + from vllm_mlx import AsyncEngineCore, SamplingParams + + model, tokenizer = model_and_tokenizer + params = SamplingParams(max_tokens=5, temperature=0.0) + + engine = AsyncEngineCore(model, tokenizer) + await engine.__aenter__() + await asyncio.sleep(0.05) + + # Generate something so the scheduler creates a BatchGenerator + rid = await engine.add_request("Hello", params) + async for out in engine.stream_outputs(rid, timeout=30): + if out.finished: + break + + # The batch generator should exist after generation + bg = engine.engine.scheduler.batch_generator + assert bg is not None, "BatchGenerator should exist after generation" + + # Stop the engine + await engine.__aexit__(None, None, None) + + # After stop, the batch generator must have been closed (set to None) + assert engine.engine.scheduler.batch_generator is None, ( + "BatchGenerator must be None after engine stop to prevent " + "GC-thread __del__ SIGABRT" + ) + + # The old generator's _old_wired_limit should also be None + # (meaning close() was called, making __del__ a no-op) + assert getattr(bg, "_old_wired_limit", None) is None, ( + "BatchGenerator._old_wired_limit must be None after close() " + "so __del__ is safe from any thread" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_bench_serve.py b/tests/test_bench_serve.py new file mode 100644 index 000000000..e5b835220 --- /dev/null +++ b/tests/test_bench_serve.py @@ -0,0 +1,826 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for vllm_mlx.bench_serve — prompt loading and sweep expansion.""" + +import asyncio +import csv +import json +import os +from pathlib import Path + +import pytest + +from vllm_mlx.bench_serve import ( + RESULT_COLUMNS, + BenchServeResult, + SweepConfig, + compute_request_metrics, + compute_summary_stats, + detect_hardware_fingerprint, + expand_sweep, + format_csv, + format_json, + format_sql, + format_table, + load_prompt_set, + parse_health_response, + parse_metrics_text, + parse_sse_line, + parse_status_response, + validate_response, +) + +# --------------------------------------------------------------------------- +# TestPromptLoading +# --------------------------------------------------------------------------- + + +class TestPromptLoading: + """Tests for load_prompt_set().""" + + def test_load_short(self): + prompts = load_prompt_set("short") + assert isinstance(prompts, list) + assert len(prompts) == 5 + for msgs in prompts: + assert isinstance(msgs, list) and msgs + first = msgs[0] + assert first["role"] == "user" + assert "content" in first and len(first["content"]) > 0 + + def test_load_medium(self): + prompts = load_prompt_set("medium") + assert isinstance(prompts, list) + assert len(prompts) == 5 + for msgs in prompts: + assert isinstance(msgs, list) and msgs + assert "role" in msgs[0] + assert "content" in msgs[0] + + def test_load_long(self): + prompts = load_prompt_set("long") + assert isinstance(prompts, list) + assert len(prompts) == 3 + for msgs in prompts: + assert isinstance(msgs, list) and msgs + assert "role" in msgs[0] + assert "content" in msgs[0] + # Long prompts should actually be long (sum across all messages) + assert sum(len(m["content"]) for m in msgs) > 1000 + + def test_load_thinking(self): + prompts = load_prompt_set("thinking") + assert isinstance(prompts, list) + assert len(prompts) == 3 + for msgs in prompts: + assert isinstance(msgs, list) and msgs + assert "role" in msgs[0] + assert "content" in msgs[0] + + def test_all_builtins_return_lists_of_message_lists(self): 
+ for name in ("short", "medium", "long", "thinking"): + prompts = load_prompt_set(name) + assert isinstance(prompts, list), f"{name}: expected list" + assert len(prompts) > 0, f"{name}: expected non-empty list" + for msgs in prompts: + assert ( + isinstance(msgs, list) and msgs + ), f"{name}: items must be non-empty lists" + for m in msgs: + assert isinstance(m, dict), f"{name}: messages must be dicts" + + def test_unknown_name_raises_file_not_found(self): + with pytest.raises(FileNotFoundError): + load_prompt_set("nonexistent_set") + + def test_unknown_path_raises_file_not_found(self): + with pytest.raises(FileNotFoundError): + load_prompt_set("/tmp/does_not_exist_bench_serve_test.json") + + def test_custom_file_flat_format(self, tmp_path: Path): + """Legacy flat format — list of dicts — still works, normalised to list-of-lists.""" + custom = [ + {"role": "user", "content": "Hello, world!"}, + {"role": "user", "content": "What is 2+2?"}, + ] + custom_file = tmp_path / "custom_prompts.json" + custom_file.write_text(json.dumps(custom)) + + loaded = load_prompt_set(str(custom_file)) + assert loaded == [[m] for m in custom] + + def test_custom_file_multi_message_format(self, tmp_path: Path): + """New format — list of message lists — used for system+user scenarios.""" + custom = [ + [ + {"role": "system", "content": "You are a code assistant."}, + {"role": "user", "content": "Hi"}, + ], + [ + {"role": "system", "content": "You are a code assistant."}, + {"role": "user", "content": "What is 2+2?"}, + ], + ] + custom_file = tmp_path / "multi.json" + custom_file.write_text(json.dumps(custom)) + + loaded = load_prompt_set(str(custom_file)) + assert loaded == custom + + def test_custom_file_preserves_extra_fields(self, tmp_path: Path): + payload = [{"role": "user", "content": "test", "extra": 42}] + p = tmp_path / "test.json" + p.write_text(json.dumps(payload)) + + result = load_prompt_set(str(p)) + # Flat format → normalised to [[msg]] + assert result == [[payload[0]]] + + def test_invalid_top_level_raises(self, tmp_path: Path): + p = tmp_path / "bad.json" + p.write_text('{"not": "a list"}') + with pytest.raises(ValueError, match="non-empty JSON list"): + load_prompt_set(str(p)) + + def test_empty_list_raises(self, tmp_path: Path): + p = tmp_path / "empty.json" + p.write_text("[]") + with pytest.raises(ValueError, match="non-empty JSON list"): + load_prompt_set(str(p)) + + def test_invalid_entry_type_raises(self, tmp_path: Path): + p = tmp_path / "weird.json" + p.write_text('["not a dict or list"]') + with pytest.raises(ValueError, match="must be dict or list"): + load_prompt_set(str(p)) + + def test_short_prompts_are_user_role(self): + prompts = load_prompt_set("short") + roles = {msgs[0]["role"] for msgs in prompts} + assert roles == {"user"} + + def test_thinking_prompts_contain_reasoning_keywords(self): + prompts = load_prompt_set("thinking") + combined = " ".join(m["content"] for msgs in prompts for m in msgs).lower() + # At least one reasoning-heavy keyword should appear + keywords = ["step", "deduc", "logic", "proof", "weighing", "clue"] + assert any(kw in combined for kw in keywords) + + +# --------------------------------------------------------------------------- +# TestExpandSweep +# --------------------------------------------------------------------------- + + +class TestExpandSweep: + """Tests for expand_sweep().""" + + def test_single_values_single_repetition(self): + result = expand_sweep(["short"], [1], [None], [""], 1) + assert result == [("short", 1, None, "", 0)] + + def 
test_two_prompt_sets_one_rep(self): + result = expand_sweep(["short", "long"], [1], [None], [""], 1) + assert len(result) == 2 + prompt_sets = [r[0] for r in result] + assert "short" in prompt_sets + assert "long" in prompt_sets + + def test_combinatorial_2x2x1x1x3(self): + # 2 prompt_sets × 2 concurrencies × 1 thinking × 1 extra_body × 3 reps = 12 + result = expand_sweep( + ["short", "medium"], + [1, 4], + [None], + [""], + 3, + ) + assert len(result) == 12 + + def test_thinking_sweep(self): + # 1 × 1 × 3 thinking × 1 × 1 = 3 + result = expand_sweep(["short"], [1], [None, True, False], [""], 1) + assert len(result) == 3 + thinking_vals = [r[2] for r in result] + assert None in thinking_vals + assert True in thinking_vals + assert False in thinking_vals + + def test_extra_body_sweep(self): + bodies = ["", '{"top_p": 0.9}', '{"top_k": 50}'] + result = expand_sweep(["short"], [1], [None], bodies, 1) + assert len(result) == 3 + extra_vals = [r[3] for r in result] + assert set(extra_vals) == set(bodies) + + def test_repetition_indices_are_zero_based(self): + result = expand_sweep(["short"], [1], [None], [""], 5) + rep_indices = [r[4] for r in result] + assert rep_indices == list(range(5)) + + def test_repetition_indices_with_multiple_combos(self): + result = expand_sweep(["short", "medium"], [1], [None], [""], 3) + # 2 combos × 3 reps = 6 entries; each combo should have reps 0,1,2 + short_reps = [r[4] for r in result if r[0] == "short"] + medium_reps = [r[4] for r in result if r[0] == "medium"] + assert sorted(short_reps) == [0, 1, 2] + assert sorted(medium_reps) == [0, 1, 2] + + def test_full_cartesian_3x2x2x2x2(self): + # 3 × 2 × 2 × 2 × 2 = 48 + result = expand_sweep( + ["short", "medium", "long"], + [1, 4], + [None, True], + ["", '{"top_p": 0.9}'], + 2, + ) + assert len(result) == 48 + + def test_result_elements_are_sweep_config_tuples(self): + result = expand_sweep(["short"], [2], [True], ['{"a":1}'], 1) + assert len(result) == 1 + config = result[0] + assert isinstance(config, tuple) + assert len(config) == 5 + prompt_set, concurrency, thinking, extra_body, rep_idx = config + assert prompt_set == "short" + assert concurrency == 2 + assert thinking is True + assert extra_body == '{"a":1}' + assert rep_idx == 0 + + def test_empty_inputs_return_empty(self): + assert expand_sweep([], [1], [None], [""], 1) == [] + assert expand_sweep(["short"], [], [None], [""], 1) == [] + assert expand_sweep(["short"], [1], [], [""], 1) == [] + assert expand_sweep(["short"], [1], [None], [], 1) == [] + + def test_zero_repetitions_return_empty(self): + result = expand_sweep(["short"], [1], [None], [""], 0) + assert result == [] + + def test_output_type_is_list(self): + result = expand_sweep(["short"], [1], [None], [""], 1) + assert isinstance(result, list) + + def test_all_prompt_sets_appear_in_output(self): + sets = ["short", "medium", "long", "thinking"] + result = expand_sweep(sets, [1], [None], [""], 1) + assert len(result) == 4 + output_sets = {r[0] for r in result} + assert output_sets == set(sets) + + def test_concurrency_values_appear_in_output(self): + result = expand_sweep(["short"], [1, 4, 16], [None], [""], 1) + concurrencies = {r[1] for r in result} + assert concurrencies == {1, 4, 16} + + def test_sweep_config_type_alias(self): + # SweepConfig should be constructable as a 5-tuple + config: SweepConfig = ("short", 1, None, "", 0) + assert config[0] == "short" + assert config[1] == 1 + assert config[2] is None + assert config[3] == "" + assert config[4] == 0 + + +# 
--------------------------------------------------------------------------- +# TestBenchServeResult +# --------------------------------------------------------------------------- + + +class TestBenchServeResult: + """Basic sanity checks for the BenchServeResult dataclass.""" + + def test_default_instantiation(self): + r = BenchServeResult() + assert r.run_id == "" + assert r.concurrency == 1 + assert r.max_tokens == 256 + assert r.enable_thinking is None + assert r.validated is True + + def test_field_assignment(self): + r = BenchServeResult( + run_id="abc123", + model_id="mlx-community/Llama-3.2-1B-Instruct-4bit", + concurrency=4, + ttft_ms=42.5, + gen_tps=123.4, + cache_hit_rate=0.85, + validated=False, + ) + assert r.run_id == "abc123" + assert r.model_id == "mlx-community/Llama-3.2-1B-Instruct-4bit" + assert r.concurrency == 4 + assert r.ttft_ms == 42.5 + assert r.gen_tps == 123.4 + assert r.cache_hit_rate == 0.85 + assert r.validated is False + + def test_enable_thinking_can_be_none_or_bool(self): + r_none = BenchServeResult(enable_thinking=None) + r_true = BenchServeResult(enable_thinking=True) + r_false = BenchServeResult(enable_thinking=False) + assert r_none.enable_thinking is None + assert r_true.enable_thinking is True + assert r_false.enable_thinking is False + + def test_has_all_required_fields(self): + r = BenchServeResult() + # Identity + assert hasattr(r, "run_id") + assert hasattr(r, "timestamp") + assert hasattr(r, "tag") + # Hardware + assert hasattr(r, "chip") + assert hasattr(r, "gpu_cores") + assert hasattr(r, "memory_gb") + assert hasattr(r, "bandwidth_gbs") + assert hasattr(r, "os_version") + # Runtime + assert hasattr(r, "model_id") + assert hasattr(r, "model_type") + assert hasattr(r, "engine_type") + assert hasattr(r, "mtp_enabled") + assert hasattr(r, "specprefill") + assert hasattr(r, "kv_quant") + assert hasattr(r, "cache_type") + # Config + assert hasattr(r, "prompt_set") + assert hasattr(r, "concurrency") + assert hasattr(r, "max_tokens") + assert hasattr(r, "enable_thinking") + assert hasattr(r, "extra_body") + assert hasattr(r, "repetition") + assert hasattr(r, "prompt_tokens") + # Latency + assert hasattr(r, "ttft_ms") + assert hasattr(r, "tpot_ms") + assert hasattr(r, "e2e_latency_ms") + # Throughput + assert hasattr(r, "gen_tps") + assert hasattr(r, "prompt_tps") + assert hasattr(r, "throughput_tps") + assert hasattr(r, "requests_per_s") + # Memory + assert hasattr(r, "metal_active_gb") + assert hasattr(r, "metal_peak_gb") + assert hasattr(r, "metal_cache_gb") + # Cache + assert hasattr(r, "cache_hits") + assert hasattr(r, "cache_misses") + assert hasattr(r, "cache_hit_rate") + assert hasattr(r, "tokens_saved") + # Validation + assert hasattr(r, "validated") + + +# --------------------------------------------------------------------------- +# TestAutoDetectionParsing (Task 3) +# --------------------------------------------------------------------------- + + +class TestAutoDetectionParsing: + """Unit tests for server response parsers and hardware fingerprint.""" + + def test_parse_health_response(self): + data = { + "status": "healthy", + "model_loaded": True, + "model_name": "mlx-community/Llama-3.2-1B-Instruct-4bit", + "model_type": "llm", + } + result = parse_health_response(data) + assert result["model_name"] == "mlx-community/Llama-3.2-1B-Instruct-4bit" + assert result["model_type"] == "llm" + + def test_parse_health_mllm(self): + data = { + "status": "healthy", + "model_loaded": True, + "model_name": "mlx-community/gemma-4-27b", + "model_type": 
"mllm", + } + result = parse_health_response(data) + assert result["model_type"] == "mllm" + assert result["model_name"] == "mlx-community/gemma-4-27b" + + def test_parse_status_response(self): + data = { + "model": "mlx-community/Llama-3.2-1B-Instruct-4bit", + "metal": { + "active_gb": 12.5, + "peak_gb": 14.0, + "cache_gb": 2.0, + }, + "cache": {"type": "paged"}, + } + result = parse_status_response(data) + assert result["model"] == "mlx-community/Llama-3.2-1B-Instruct-4bit" + assert result["metal_active_gb"] == pytest.approx(12.5) + assert result["metal_peak_gb"] == pytest.approx(14.0) + assert result["metal_cache_gb"] == pytest.approx(2.0) + assert result["cache_type"] == "paged" + + def test_parse_status_no_metal(self): + data = {"model": "some-model"} + result = parse_status_response(data) + assert result["metal_active_gb"] == pytest.approx(0.0) + assert result["metal_peak_gb"] == pytest.approx(0.0) + assert result["metal_cache_gb"] == pytest.approx(0.0) + assert result["cache_type"] == "" + + def test_parse_metrics_text_with_cache_stats(self): + text = ( + "# HELP vllm_prefix_cache_hits_total Total prefix cache hits\n" + "# TYPE vllm_prefix_cache_hits_total counter\n" + "vllm_prefix_cache_hits_total 42\n" + "# HELP vllm_prefix_cache_misses_total Total prefix cache misses\n" + "# TYPE vllm_prefix_cache_misses_total counter\n" + "vllm_prefix_cache_misses_total 8\n" + "# HELP vllm_prefix_cache_tokens_saved_total Tokens saved\n" + "# TYPE vllm_prefix_cache_tokens_saved_total counter\n" + "vllm_prefix_cache_tokens_saved_total 1024\n" + ) + result = parse_metrics_text(text) + assert result["cache_hits"] == 42 + assert result["cache_misses"] == 8 + assert result["tokens_saved"] == 1024 + + def test_parse_metrics_empty(self): + result = parse_metrics_text("") + assert result["cache_hits"] == 0 + assert result["cache_misses"] == 0 + assert result["tokens_saved"] == 0 + + def test_detect_hardware_fingerprint(self): + result = detect_hardware_fingerprint() + assert isinstance(result, dict) + assert "chip" in result + assert "gpu_cores" in result + assert "memory_gb" in result + assert "bandwidth_gbs" in result + assert "os_version" in result + assert isinstance(result["os_version"], str) + assert len(result["os_version"]) > 0 + + +# --------------------------------------------------------------------------- +# TestSSEParsing (Task 4) +# --------------------------------------------------------------------------- + + +class TestSSEParsing: + """Unit tests for parse_sse_line().""" + + def _make_line(self, delta_content=None, finish_reason=None, usage=None): + """Build a synthetic SSE data line.""" + chunk: dict = { + "choices": [ + { + "delta": ( + {"content": delta_content} if delta_content is not None else {} + ), + "finish_reason": finish_reason, + } + ] + } + if usage is not None: + chunk["usage"] = usage + return f"data: {json.dumps(chunk)}" + + def test_parse_data_line(self): + line = self._make_line(delta_content="Hello") + result = parse_sse_line(line) + assert result is not None + assert result["content"] == "Hello" + assert result["finish_reason"] is None + assert result["usage"] is None + + def test_parse_done(self): + assert parse_sse_line("data: [DONE]") is None + + def test_parse_empty_line(self): + assert parse_sse_line("") is None + + def test_parse_comment_line(self): + assert parse_sse_line(": keep-alive") is None + + def test_parse_no_content(self): + line = self._make_line() # delta has no content key + result = parse_sse_line(line) + assert result is not None + assert 
result["content"] == "" + + def test_parse_with_usage(self): + usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + line = self._make_line(delta_content=None, usage=usage) + result = parse_sse_line(line) + assert result is not None + assert result["usage"] == usage + + def test_parse_with_finish_reason(self): + line = self._make_line(finish_reason="stop") + result = parse_sse_line(line) + assert result is not None + assert result["finish_reason"] == "stop" + + +# --------------------------------------------------------------------------- +# TestRequestMetrics (Task 4) +# --------------------------------------------------------------------------- + + +class TestRequestMetrics: + """Unit tests for compute_request_metrics().""" + + def test_basic_metrics(self): + # Simulate: request starts at 0, first token at 0.1s, then 4 more tokens + # every 0.02s. + t_start = 0.0 + t_first = 0.1 + token_times = [t_first + i * 0.02 for i in range(5)] + t_end = token_times[-1] + 0.001 # tiny extra after last token + prompt_tokens = 20 + completion_tokens = 5 + + metrics = compute_request_metrics( + t_start=t_start, + t_first_token=t_first, + token_times=token_times, + t_end=t_end, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + # TTFT should be ~100 ms + assert metrics["ttft_ms"] == pytest.approx(100.0, abs=1.0) + # TPOT should be ~20 ms (inter-token interval) + assert metrics["tpot_ms"] == pytest.approx(20.0, abs=1.0) + # E2E latency > TTFT + assert metrics["e2e_latency_ms"] > metrics["ttft_ms"] + # gen_tps > 0 + assert metrics["gen_tps"] > 0.0 + # prompt_tps > 0 + assert metrics["prompt_tps"] > 0.0 + + def test_single_token(self): + t_start = 0.0 + t_first = 0.05 + token_times = [t_first] + t_end = t_first + 0.001 + metrics = compute_request_metrics( + t_start=t_start, + t_first_token=t_first, + token_times=token_times, + t_end=t_end, + prompt_tokens=10, + completion_tokens=1, + ) + # Single token → no inter-token interval → TPOT = 0.0 + assert metrics["tpot_ms"] == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# TestValidation (Task 5) +# --------------------------------------------------------------------------- + + +class TestValidation: + """Unit tests for validate_response().""" + + def test_valid_response(self): + is_valid, msg = validate_response( + finish_reason="stop", content="Hello world", status_code=200 + ) + assert is_valid is True + assert msg == "" + + def test_empty_content(self): + is_valid, msg = validate_response( + finish_reason="stop", content="", status_code=200 + ) + assert is_valid is False + assert "empty" in msg.lower() + + def test_length_truncation(self): + is_valid, msg = validate_response( + finish_reason="length", content="partial text", status_code=200 + ) + assert is_valid is False + assert "length" in msg.lower() + + def test_missing_finish_reason(self): + is_valid, msg = validate_response( + finish_reason=None, content="some text", status_code=200 + ) + assert is_valid is False + assert msg != "" + + def test_http_error(self): + is_valid, msg = validate_response( + finish_reason="stop", content="error body", status_code=500 + ) + assert is_valid is False + assert "500" in msg + + +# --------------------------------------------------------------------------- +# TestSummaryStats (Task 5) +# --------------------------------------------------------------------------- + + +class TestSummaryStats: + """Unit tests for compute_summary_stats().""" + + def 
test_basic_summary(self): + stats = compute_summary_stats([10.0, 20.0, 30.0, 40.0, 50.0]) + assert stats["mean"] == pytest.approx(30.0) + assert stats["min"] == pytest.approx(10.0) + assert stats["max"] == pytest.approx(50.0) + assert stats["p50"] == pytest.approx(30.0) + + def test_single_value(self): + stats = compute_summary_stats([42.0]) + assert stats["mean"] == pytest.approx(42.0) + assert stats["stddev"] == pytest.approx(0.0) + assert stats["min"] == pytest.approx(42.0) + assert stats["max"] == pytest.approx(42.0) + assert stats["p50"] == pytest.approx(42.0) + assert stats["p95"] == pytest.approx(42.0) + assert stats["p99"] == pytest.approx(42.0) + + def test_empty_raises(self): + with pytest.raises(ValueError): + compute_summary_stats([]) + + +# --------------------------------------------------------------------------- +# Formatter helpers (Task 6) +# --------------------------------------------------------------------------- + + +def _make_sample_result(**overrides) -> BenchServeResult: + """Return a BenchServeResult with realistic defaults, accepting overrides.""" + defaults = dict( + run_id="run-abc123", + timestamp="2026-04-17T10:00:00Z", + tag="ci", + chip="Apple M3 Max", + gpu_cores=40, + memory_gb=128.0, + bandwidth_gbs=400.0, + os_version="macOS 15.4", + model_id="mlx-community/gemma-3-4b-it-4bit", + model_type="llm", + engine_type="vllm-mlx", + mtp_enabled=False, + specprefill=False, + kv_quant="", + cache_type="paged", + prompt_set="short", + concurrency=4, + max_tokens=256, + enable_thinking=None, + extra_body="", + repetition=0, + prompt_tokens=32, + ttft_ms=85.3, + tpot_ms=12.4, + e2e_latency_ms=420.7, + gen_tps=80.6, + prompt_tps=310.2, + throughput_tps=75.1, + requests_per_s=2.4, + metal_active_gb=12.5, + metal_peak_gb=14.0, + metal_cache_gb=2.0, + cache_hits=10, + cache_misses=2, + cache_hit_rate=0.833, + tokens_saved=320, + validated=True, + ) + defaults.update(overrides) + return BenchServeResult(**defaults) + + +# --------------------------------------------------------------------------- +# TestFormatters (Task 6) +# --------------------------------------------------------------------------- + + +class TestFormatters: + """Unit tests for output formatter functions.""" + + def test_format_table_not_empty(self): + r = _make_sample_result() + output = format_table([r]) + assert len(output) > 0 + assert r.model_id in output or str(r.ttft_ms) in output or "85" in output + + def test_format_json_roundtrip(self): + r1 = _make_sample_result(run_id="r1") + r2 = _make_sample_result(run_id="r2", concurrency=8) + output = format_json([r1, r2]) + parsed = json.loads(output) + assert len(parsed) == 2 + assert parsed[0]["run_id"] == "r1" + assert parsed[1]["run_id"] == "r2" + + def test_format_csv_parseable(self): + r = _make_sample_result() + output = format_csv([r]) + reader = csv.DictReader(output.splitlines()) + rows = list(reader) + assert len(rows) == 1 + assert "model_id" in rows[0] + assert "ttft_ms" in rows[0] + assert rows[0]["model_id"] == r.model_id + + def test_format_sql_valid(self): + r = _make_sample_result() + output = format_sql([r]) + assert "CREATE TABLE" in output + assert "INSERT" in output + assert "bench_serve" in output + + def test_format_sql_escapes_quotes(self): + r = _make_sample_result(tag="it's a test") + output = format_sql([r]) + assert "it''s a test" in output + + def test_format_sql_handles_nan_inf(self): + r = _make_sample_result(ttft_ms=float("nan"), gen_tps=float("inf")) + output = format_sql([r]) + # NaN and Inf should become NULL, not 
invalid SQL literals + assert "nan" not in output.lower().split("'")[-1] # not outside strings + assert "inf" not in output.lower().split("'")[-1] + assert "NULL" in output + + def test_result_columns_match_dataclass(self): + import dataclasses + + field_names = {f.name for f in dataclasses.fields(BenchServeResult)} + assert set(RESULT_COLUMNS) == field_names + + +# --------------------------------------------------------------------------- +# TestBenchServeIntegration (Task 8) +# --------------------------------------------------------------------------- + + +BENCH_SERVE_URL = os.environ.get("BENCH_SERVE_TEST_URL") + + +@pytest.mark.skipif( + BENCH_SERVE_URL is None, + reason="Set BENCH_SERVE_TEST_URL to run integration tests", +) +class TestBenchServeIntegration: + """Integration tests requiring a running vllm-mlx server.""" + + def test_smoke_run(self): + """End-to-end: run bench-serve with minimal config against a real server.""" + from vllm_mlx.bench_serve import run_bench_serve + + results = asyncio.run( + run_bench_serve( + url=BENCH_SERVE_URL, + prompt_sets=["short"], + concurrencies=[1], + repetitions=1, + warmup=0, + max_tokens=32, + fmt="json", + scrape=False, + ) + ) + assert len(results) == 1 + r = results[0] + assert r.ttft_ms > 0 + assert r.gen_tps > 0 + assert r.model_id != "" + assert r.validated is True + + def test_sql_output_is_valid(self): + """Verify SQL output contains CREATE TABLE and INSERT.""" + from vllm_mlx.bench_serve import run_bench_serve, format_sql + + results = asyncio.run( + run_bench_serve( + url=BENCH_SERVE_URL, + prompt_sets=["short"], + concurrencies=[1], + repetitions=1, + warmup=0, + max_tokens=16, + fmt="table", + scrape=False, + ) + ) + sql = format_sql(results) + assert "CREATE TABLE IF NOT EXISTS bench_serve" in sql + assert "INSERT INTO bench_serve" in sql diff --git a/tests/test_chat_template_kwargs.py b/tests/test_chat_template_kwargs.py new file mode 100644 index 000000000..94ee0af98 --- /dev/null +++ b/tests/test_chat_template_kwargs.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for chat template kwargs forwarding.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +import vllm_mlx.server as srv +from vllm_mlx.engine.base import GenerationOutput + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +def test_chat_completion_request_preserves_chat_template_kwargs(): + request = srv.ChatCompletionRequest( + model="test-model", + messages=[srv.Message(role="user", content="Hello")], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert request.chat_template_kwargs == {"enable_thinking": False} + + +def test_batched_engine_applies_chat_template_kwargs(): + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False): + from vllm_mlx.engine.batched import BatchedEngine + + engine = BatchedEngine("test-model") + engine._tokenizer = MagicMock() + engine._tokenizer.apply_chat_template.return_value = "prompt" + + prompt = engine._apply_chat_template( + [{"role": "user", "content": "Hello"}], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert prompt == "prompt" + engine._tokenizer.apply_chat_template.assert_called_once() + assert ( + engine._tokenizer.apply_chat_template.call_args.kwargs["enable_thinking"] + is False + ) + + +def test_batched_engine_mllm_falls_back_to_tokenizer_when_processor_has_no_template(): + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=True): 
+ from vllm_mlx.engine.batched import BatchedEngine + + engine = BatchedEngine("test-mllm-model") + engine._is_mllm = True + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "prompt-from-tokenizer" + + processor = MagicMock() + processor.tokenizer = tokenizer + processor.apply_chat_template.side_effect = ValueError( + "Cannot use apply_chat_template because this processor does not have a chat template." + ) + engine._processor = processor + + prompt = engine._apply_chat_template( + [{"role": "user", "content": "Hello"}], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert prompt == "prompt-from-tokenizer" + processor.apply_chat_template.assert_called_once() + tokenizer.apply_chat_template.assert_called_once() + + +def test_chat_completion_endpoint_forwards_chat_template_kwargs(): + captured = {} + + class FakeEngine: + model_name = "test-model" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="ORBIT", + prompt_tokens=4, + completion_tokens=1, + finish_reason="stop", + ) + + client = TestClient(srv.app) + original_engine = srv._engine + original_model_name = srv._model_name + srv._engine = FakeEngine() + srv._model_name = "test-model" + try: + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Reply with ORBIT."}], + "max_tokens": 8, + "chat_template_kwargs": {"enable_thinking": False}, + }, + ) + finally: + srv._engine = original_engine + srv._model_name = original_model_name + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False} + assert response.json()["choices"][0]["message"]["content"] == "ORBIT" + + +def test_chat_completion_endpoint_applies_server_default_chat_template_kwargs(): + captured = {} + + class FakeEngine: + model_name = "test-model" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="ORBIT", + prompt_tokens=4, + completion_tokens=1, + finish_reason="stop", + ) + + client = TestClient(srv.app) + original_engine = srv._engine + original_model_name = srv._model_name + original_defaults = getattr(srv, "_default_chat_template_kwargs", None) + srv._engine = FakeEngine() + srv._model_name = "test-model" + srv._default_chat_template_kwargs = {"enable_thinking": False} + try: + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Reply with ORBIT."}], + "max_tokens": 8, + }, + ) + finally: + srv._engine = original_engine + srv._model_name = original_model_name + srv._default_chat_template_kwargs = original_defaults + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False} + assert response.json()["choices"][0]["message"]["content"] == "ORBIT" + + +def test_chat_completion_endpoint_request_kwargs_override_server_defaults(): + captured = {} + + class FakeEngine: + model_name = "test-model" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="ORBIT", + prompt_tokens=4, + completion_tokens=1, + finish_reason="stop", + ) + + client = 
TestClient(srv.app) + original_engine = srv._engine + original_model_name = srv._model_name + original_defaults = getattr(srv, "_default_chat_template_kwargs", None) + srv._engine = FakeEngine() + srv._model_name = "test-model" + srv._default_chat_template_kwargs = { + "enable_thinking": False, + "server_default_only": "yes", + } + try: + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Reply with ORBIT."}], + "max_tokens": 8, + "chat_template_kwargs": { + "enable_thinking": True, + "request_only": 1, + }, + }, + ) + finally: + srv._engine = original_engine + srv._model_name = original_model_name + srv._default_chat_template_kwargs = original_defaults + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == { + "enable_thinking": True, + "server_default_only": "yes", + "request_only": 1, + } + assert response.json()["choices"][0]["message"]["content"] == "ORBIT" + + +def test_llm_chat_applies_chat_template_kwargs_before_generate(): + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel.__new__(MLXLanguageModel) + model._loaded = True + model.tokenizer = MagicMock() + model.tokenizer.apply_chat_template.return_value = "prompt" + model.generate = MagicMock(return_value="ok") + + result = model.chat( + [{"role": "user", "content": "Hello"}], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert result == "ok" + model.tokenizer.apply_chat_template.assert_called_once() + assert ( + model.tokenizer.apply_chat_template.call_args.kwargs["enable_thinking"] is False + ) + model.generate.assert_called_once() + + +@pytest.mark.anyio +async def test_simple_engine_mllm_chat_forwards_chat_template_kwargs(): + from vllm_mlx.engine.simple import SimpleEngine + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._is_mllm = True + engine._model = MagicMock() + engine._model.chat = MagicMock( + return_value=SimpleNamespace( + text="OK", + prompt_tokens=5, + completion_tokens=1, + finish_reason="stop", + ) + ) + + await engine.chat( + [{"role": "user", "content": "Hello"}], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert engine._model.chat.call_args.kwargs["chat_template_kwargs"] == { + "enable_thinking": False + } + + +@pytest.mark.anyio +async def test_simple_engine_stream_generate_text_applies_chat_template_kwargs(): + from vllm_mlx.engine.simple import SimpleEngine + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._is_mllm = True + engine._text_tokenizer = MagicMock() + engine._text_tokenizer.apply_chat_template.return_value = "prompt" + engine._text_model = MagicMock() + engine._text_model.model = MagicMock() + + with ( + patch("mlx_lm.stream_generate", return_value=iter(())), + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch("mlx_lm.sample_utils.make_sampler", return_value=object()), + ): + chunks = [ + chunk + async for chunk in engine._stream_generate_text( + [{"role": "user", "content": "Hello"}], + max_tokens=8, + temperature=0.7, + top_p=0.9, + chat_template_kwargs={"enable_thinking": False}, + ) + ] + + assert chunks + engine._text_tokenizer.apply_chat_template.assert_called_once() + assert ( + engine._text_tokenizer.apply_chat_template.call_args.kwargs[ + "enable_thinking" + ] + is False + ) diff --git 
a/tests/test_constrained_decoding.py b/tests/test_constrained_decoding.py new file mode 100644 index 000000000..099872418 --- /dev/null +++ b/tests/test_constrained_decoding.py @@ -0,0 +1,619 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Unit tests for JSON-schema constrained decoding.
+
+Covers:
+
+* ``vllm_mlx.constrained.cache`` — tokenizer-data construction and caching.
+* ``vllm_mlx.constrained.json_schema_processor.JSONSchemaLogitsProcessor``
+  — mask shape, token filtering, schema enforcement.
+* ``vllm_mlx.api.tool_calling.build_json_logits_processor`` — builder glue.
+* ``vllm_mlx.api.anthropic_adapter.anthropic_to_openai`` — propagation of
+  the OpenAI-compatible ``response_format`` extension field from
+  ``/v1/messages`` requests.
+
+These tests are pure-logic: they use a minimal fake tokenizer (ASCII
+vocabulary) to avoid pulling in a 200k-entry HF tokenizer at test time.
+They skip gracefully when ``lm-format-enforcer`` is not installed, so CI
+running with a minimal dependency set still passes.
+"""
+
+from __future__ import annotations
+
+import pytest
+
+from vllm_mlx.api.anthropic_adapter import anthropic_to_openai
+from vllm_mlx.api.anthropic_models import AnthropicMessage, AnthropicRequest
+from vllm_mlx.api.models import ResponseFormat, ResponseFormatJsonSchema
+from vllm_mlx.api.tool_calling import build_json_logits_processor
+from vllm_mlx.constrained import is_available
+from vllm_mlx.constrained.cache import clear_cache, get_tokenizer_data
+
+# Skip all processor-level tests if the optional dependency is missing.
+pytestmark_lmfe = pytest.mark.skipif(
+    not is_available(),
+    reason="lm-format-enforcer not installed",
+)
+
+
+# ---------------------------------------------------------------------------
+# Fake tokenizer — just enough API surface for the cache + processor.
+# ---------------------------------------------------------------------------
+
+
+class _FakeTokenizer:
+    """
+    Deterministic, tiny tokenizer for unit tests.
+
+    Vocabulary is a small ASCII set (digits, brackets, whitespace, quotes,
+    a handful of letters). IDs start at 0 and are stable across instances.
+    Special tokens (EOS/BOS) occupy the end of the range.
+    """
+
+    def __init__(self) -> None:
+        chars = list('0123456789{}[]:," \n\ttrueflasnul')
+        # Deduplicate while preserving order.
+        seen: set[str] = set()
+        unique: list[str] = []
+        for c in chars:
+            if c not in seen:
+                seen.add(c)
+                unique.append(c)
+        self._id_to_tok = unique + ["<eos>", "<bos>"]
+        self._tok_to_id = {t: i for i, t in enumerate(self._id_to_tok)}
+        self.vocab_size = len(self._id_to_tok)
+        self.eos_token_id = self._tok_to_id["<eos>"]
+        self.all_special_ids = [
+            self._tok_to_id["<eos>"],
+            self._tok_to_id["<bos>"],
+        ]
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    def encode(self, text: str) -> list[int]:
+        out: list[int] = []
+        for ch in text:
+            if ch in self._tok_to_id:
+                out.append(self._tok_to_id[ch])
+        return out
+
+    def decode(self, ids: list[int]) -> str:
+        parts: list[str] = []
+        for i in ids:
+            if 0 <= i < len(self._id_to_tok):
+                tok = self._id_to_tok[i]
+                if not tok.startswith("<"):
+                    parts.append(tok)
+        return "".join(parts)
+
+    def get_vocab(self) -> dict[str, int]:
+        return dict(self._tok_to_id)
+
+
+@pytest.fixture(autouse=True)
+def _reset_cache():
+    clear_cache()
+    yield
+    clear_cache()
+
+
+# ---------------------------------------------------------------------------
+# Tokenizer-data cache.
+# --------------------------------------------------------------------------- + + +@pytestmark_lmfe +class TestTokenizerDataCache: + def test_builds_data_for_fake_tokenizer(self): + tok = _FakeTokenizer() + data = get_tokenizer_data(tok) + assert data is not None + assert data.vocab_size == tok.vocab_size + + def test_cache_reuse_same_tokenizer(self): + tok = _FakeTokenizer() + first = get_tokenizer_data(tok) + second = get_tokenizer_data(tok) + assert first is second # cached object returned verbatim + + def test_separate_tokenizers_get_separate_entries(self): + a = _FakeTokenizer() + b = _FakeTokenizer() + assert get_tokenizer_data(a) is not get_tokenizer_data(b) + + +# --------------------------------------------------------------------------- +# build_json_logits_processor() — builder glue. +# --------------------------------------------------------------------------- + + +class TestBuildJsonLogitsProcessor: + """These tests don't require lm-format-enforcer — they just check glue.""" + + def test_text_returns_none(self): + tok = _FakeTokenizer() + result = build_json_logits_processor({"type": "text"}, tok) + assert result is None + + def test_none_returns_none(self): + tok = _FakeTokenizer() + result = build_json_logits_processor(None, tok) + assert result is None + + def test_unsupported_type_returns_none(self): + tok = _FakeTokenizer() + result = build_json_logits_processor({"type": "xml"}, tok) + assert result is None + + @pytestmark_lmfe + def test_json_object_builds_processor(self): + tok = _FakeTokenizer() + result = build_json_logits_processor({"type": "json_object"}, tok) + assert result is not None + + @pytestmark_lmfe + def test_json_schema_builds_processor(self): + tok = _FakeTokenizer() + response_format = { + "type": "json_schema", + "json_schema": { + "name": "person", + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + }, + } + result = build_json_logits_processor(response_format, tok) + assert result is not None + + @pytestmark_lmfe + def test_json_schema_pydantic_model(self): + tok = _FakeTokenizer() + response_format = ResponseFormat( + type="json_schema", + json_schema=ResponseFormatJsonSchema( + name="result", + schema={"type": "object", "properties": {"ok": {"type": "boolean"}}}, + ), + ) + result = build_json_logits_processor(response_format, tok) + assert result is not None + + +# --------------------------------------------------------------------------- +# JSONSchemaLogitsProcessor — mask semantics. +# --------------------------------------------------------------------------- + + +@pytestmark_lmfe +class TestProcessorMask: + def _make_processor(self, schema: dict | None = None): + from vllm_mlx.constrained import JSONSchemaLogitsProcessor + + tok = _FakeTokenizer() + return JSONSchemaLogitsProcessor(schema, tok), tok + + def test_mask_shape_matches_logits(self): + import mlx.core as mx + + processor, tok = self._make_processor() + # Emulate: one prompt token already consumed, cursor at first + # generation step (tokens contains prompt+1 generated). 
+ prompt = tok.encode(" ") # a single harmless token as "prompt" + tokens = mx.array(prompt) + logits = mx.zeros((tok.vocab_size,)) + masked = processor(tokens, logits) + assert masked.shape == logits.shape + + def test_mask_shape_matches_2d_logits(self): + import mlx.core as mx + + processor, tok = self._make_processor() + prompt = tok.encode(" ") + tokens = mx.array(prompt) + logits = mx.zeros((1, tok.vocab_size)) + masked = processor(tokens, logits) + assert masked.shape == logits.shape + + def test_allows_at_least_one_token_at_start(self): + """At the start of a JSON value, at least ``{`` or ``[`` must pass.""" + import mlx.core as mx + + processor, tok = self._make_processor() + prompt = tok.encode(" ") + tokens = mx.array(prompt) + logits = mx.zeros((tok.vocab_size,)) + masked = processor(tokens, logits) + # At least one finite entry must remain. + finite = (masked != -float("inf")).sum().item() + assert finite >= 1 + + def test_processor_never_crashes_on_arbitrary_state(self): + """Defensive: even with unexpected token history, returns logits.""" + import mlx.core as mx + + processor, _ = self._make_processor() + # Pass nonsensical tokens — the processor's except handler should + # catch any parser error and return the original logits unchanged. + tokens = mx.array([999999]) # out-of-vocab id + logits = mx.zeros((8,)) + result = processor(tokens, logits) + assert result.shape == logits.shape + + +# --------------------------------------------------------------------------- +# Anthropic adapter — response_format propagation. +# --------------------------------------------------------------------------- + + +class TestAnthropicAdapterResponseFormat: + def test_response_format_propagates_dict(self): + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="Hi")], + max_tokens=50, + response_format={"type": "json_object"}, + ) + openai_req = anthropic_to_openai(req) + assert openai_req.response_format is not None + # ChatCompletionRequest coerces the dict into ResponseFormat. + rf_type = ( + openai_req.response_format.type + if hasattr(openai_req.response_format, "type") + else openai_req.response_format.get("type") + ) + assert rf_type == "json_object" + + def test_response_format_json_schema_propagates(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}}, + "required": ["name"], + } + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="Give a name")], + max_tokens=50, + response_format={ + "type": "json_schema", + "json_schema": {"name": "person", "schema": schema}, + }, + ) + openai_req = anthropic_to_openai(req) + assert openai_req.response_format is not None + rf = openai_req.response_format + rf_type = rf.type if hasattr(rf, "type") else rf.get("type") + assert rf_type == "json_schema" + + def test_missing_response_format_is_none(self): + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="Hi")], + max_tokens=50, + ) + openai_req = anthropic_to_openai(req) + assert openai_req.response_format is None + + +# --------------------------------------------------------------------------- +# _simplify_schema — metadata stripping and anyOf flattening. 
+# --------------------------------------------------------------------------- + + +class TestSimplifySchema: + """Test ``_simplify_schema`` handles metadata and nested anyOf correctly.""" + + def test_strips_default_and_metadata(self): + from vllm_mlx.constrained.json_schema_processor import _simplify_schema + + schema = { + "type": "object", + "properties": { + "name": { + "type": "string", + "default": "unknown", + "title": "Name", + "description": "The name", + "examples": ["Alice"], + }, + }, + } + result = _simplify_schema(schema) + name_prop = result["properties"]["name"] + for kw in ("default", "title", "description", "examples"): + assert kw not in name_prop, f"{kw!r} should have been stripped" + assert name_prop["type"] == "string" + + def test_flattens_nested_anyof(self): + from vllm_mlx.constrained.json_schema_processor import _simplify_schema + + schema = { + "anyOf": [ + {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + {"type": "null"}, + ] + } + result = _simplify_schema(schema) + # Nested anyOf should be flattened to 3 branches. + assert "anyOf" in result + assert len(result["anyOf"]) == 3 + + def test_fact_batch_schema_simplifies_cleanly(self): + """Regression test: the ``fact_batch`` schema that caused enforcer + to get stuck in production (nested anyOf + default + $ref + not).""" + from vllm_mlx.constrained.json_schema_processor import _simplify_schema + + fact_batch_schema = { + "type": "object", + "properties": { + "facts": { + "anyOf": [ + { + "anyOf": [ + {"not": {"$ref": "#/definitions/OpenAiAnyType"}}, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "kind": { + "type": "string", + "enum": [ + "number", + "price", + "product_name", + ], + }, + "value": { + "type": "string", + "minLength": 1, + "maxLength": 300, + }, + "confidence": { + "type": "string", + "enum": [ + "high", + "medium", + "low", + ], + }, + }, + "required": ["kind", "value", "confidence"], + "additionalProperties": False, + }, + "maxItems": 6, + }, + ], + "default": [], + }, + {"type": "null"}, + ] + } + }, + "required": ["facts"], + "additionalProperties": False, + "definitions": { + "OpenAiAnyType": { + "type": ["string", "number", "integer", "boolean", "array", "null"], + "items": {"$ref": "#/definitions/OpenAiAnyType"}, + } + }, + "$schema": "https://json-schema.org/draft/2019-09/schema#", + } + + result = _simplify_schema(fact_batch_schema) + + # $schema must be stripped. + assert "$schema" not in result + # definitions must be consumed. + assert "definitions" not in result + + # The ``facts`` property must exist and contain a flat anyOf. + facts = result["properties"]["facts"] + assert "anyOf" in facts + # ``default`` must be stripped from all levels. + assert "default" not in facts + for branch in facts["anyOf"]: + assert "default" not in branch + # No nested anyOf wrappers should remain — each branch is either + # the array schema or ``{type: null}``. + for branch in facts["anyOf"]: + if "anyOf" in branch: + # Inner anyOf should have been flattened into the outer one. + pytest.fail(f"Nested anyOf still present in branch: {branch!r}") + + +# --------------------------------------------------------------------------- +# Incremental caching — O(n) suffix decode + context tracking. +# --------------------------------------------------------------------------- + + +class _BPETokenizer(_FakeTokenizer): + """Fake tokenizer that simulates BPE whitespace-prefix behaviour. 
+ + In real BPE tokenizers (Mistral, Llama, Qwen), a token like ``world`` + decodes to ``" world"`` (with leading space) when preceded by another + token, but ``"world"`` when decoded alone. This means:: + + decode([tok_hello]) + decode([tok_world]) != decode([tok_hello, tok_world]) + + This tokenizer adds multi-character tokens with context-dependent + whitespace to exercise prefix-stability in ``_decode_suffix``. + """ + + def __init__(self) -> None: + super().__init__() + # Add multi-char tokens that simulate BPE merges. + # Token ids continue from the parent's vocab. + self._multi = { + self.vocab_size: "Hello", + self.vocab_size + 1: " world", # note: leading space in context + self.vocab_size + 2: " value", + } + self._multi_alone = { + self.vocab_size: "Hello", + self.vocab_size + 1: "world", # NO leading space when decoded alone + self.vocab_size + 2: "value", + } + self.vocab_size += len(self._multi) + + def encode(self, text: str) -> list[int]: + # Simple: just use parent for JSON chars, multi tokens for words. + return super().encode(text) + + def decode(self, ids: list[int]) -> str: + parts: list[str] = [] + for idx, tok_id in enumerate(ids): + if tok_id in self._multi: + # Simulate BPE: leading space only when preceded by another token. + if idx > 0: + parts.append(self._multi[tok_id]) + else: + parts.append(self._multi_alone[tok_id]) + elif 0 <= tok_id < len(self._id_to_tok): + tok = self._id_to_tok[tok_id] + if not tok.startswith("<"): + parts.append(tok) + return "".join(parts) + + +@pytestmark_lmfe +class TestIncrementalCaching: + """Verify that the incremental decode / context-tracking optimisations + produce correct results identical to a full re-scan on every step.""" + + def _make_processor(self, schema: dict | None = None, tokenizer=None): + from vllm_mlx.constrained import JSONSchemaLogitsProcessor + + tok = tokenizer or _FakeTokenizer() + return JSONSchemaLogitsProcessor(schema, tok), tok + + def test_incremental_decode_matches_full_decode(self): + """Growing suffix one token at a time should produce the same decoded + text as a full decode on each step.""" + processor, tok = self._make_processor() + text = '{"a": 1}' + token_ids = tok.encode(text) + + for step in range(1, len(token_ids) + 1): + suffix = token_ids[:step] + # Incremental path (cached). + inc_text = processor._decode_suffix(suffix) + # Full decode for reference. + full_text = tok.decode(suffix) + assert ( + inc_text == full_text + ), f"Step {step}: incremental={inc_text!r} != full={full_text!r}" + + def test_non_concatenative_tokenizer_decode(self): + """Regression test: BPE tokenizers where per-token decode differs + from full-sequence decode (whitespace as token prefix). + + decode([tok_hello]) + decode([tok_world]) = "Helloworld" + decode([tok_hello, tok_world]) = "Hello world" + + _decode_suffix must always match the full decode, never per-token concat. + """ + bpe_tok = _BPETokenizer() + processor, _ = self._make_processor(tokenizer=bpe_tok) + + # Token ids for the multi-char tokens. + hello_id = bpe_tok.vocab_size - 3 # "Hello" + world_id = bpe_tok.vocab_size - 2 # " world" in context + + # Verify the tokenizer IS non-concatenative. + alone_concat = bpe_tok.decode([hello_id]) + bpe_tok.decode([world_id]) + full_decode = bpe_tok.decode([hello_id, world_id]) + assert alone_concat == "Helloworld" + assert full_decode == "Hello world" + assert alone_concat != full_decode, "Test tokenizer must be non-concatenative" + + # Step through: _decode_suffix must match full decode at every step. 
+ suffix = [hello_id] + result1 = processor._decode_suffix(suffix) + assert result1 == bpe_tok.decode([hello_id]) + + suffix = [hello_id, world_id] + result2 = processor._decode_suffix(suffix) + assert result2 == full_decode, ( + f"_decode_suffix returned {result2!r}, expected {full_decode!r}. " + f"Per-token concat would give {alone_concat!r} — this is the bug." + ) + + def test_incremental_context_tracks_braces(self): + """Bracket/brace depth should update correctly through incremental + scanning, matching a full re-scan.""" + processor, tok = self._make_processor( + {"type": "object", "properties": {"a": {"type": "integer"}}} + ) + text = '{"a": 1}' + token_ids = tok.encode(text) + + # Simulate stepping through generation. + processor._prompt_len = 0 + for step in range(1, len(token_ids) + 1): + suffix = token_ids[:step] + ctx = processor._get_json_context(suffix) + decoded = tok.decode(suffix) + + # Verify brace depth at key points. + if decoded == "{": + assert processor._brace_depth == 1 + if decoded == '{"a": 1}': + assert processor._brace_depth == 0 + assert ctx == "other" + + def test_bracket_precheck_avoids_json_loads(self): + """When brackets are unbalanced, _suffix_is_complete_json should + return False without calling json.loads (fast pre-check).""" + processor, tok = self._make_processor() + # Feed partial JSON — braces are open. + partial = '{"a": ' + token_ids = tok.encode(partial) + + processor._prompt_len = 0 + # Update context state. + processor._get_json_context(token_ids) + assert processor._brace_depth > 0 + + # Pre-check should short-circuit. + assert not processor._suffix_is_complete_json(token_ids) + + def test_complete_json_detected(self): + processor, tok = self._make_processor() + text = '{"a": 1}' + token_ids = tok.encode(text) + + processor._prompt_len = 0 + # Must update context first (populates bracket counters). 
+ processor._get_json_context(token_ids) + assert processor._brace_depth == 0 + assert processor._suffix_is_complete_json(token_ids) + + def test_numpy_mask_matches_original(self): + """The numpy-based _build_allow_mask should produce the same mask + as the old Python-list approach.""" + import numpy as np + + processor, tok = self._make_processor() + allowed = [0, 3, 7, 10] + vocab = tok.vocab_size + + mask = processor._build_allow_mask(allowed, vocab) + arr = np.array(mask) + + for i in range(vocab): + if i in allowed: + assert arr[i] == 0.0, f"Position {i} should be 0.0" + else: + assert arr[i] == -np.inf, f"Position {i} should be -inf" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 65774df93..41790eaad 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -161,7 +161,7 @@ def test_batch_input_preserves_order(self, client): texts = ["first", "second", "third"] mock_engine = MagicMock() - mock_engine.model_name = "test-embed" + mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit" mock_engine.embed.return_value = [ [1.0, 0.0], [0.0, 1.0], @@ -174,7 +174,7 @@ def test_batch_input_preserves_order(self, client): try: resp = client.post( "/v1/embeddings", - json={"model": "test-embed", "input": texts}, + json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": texts}, ) finally: srv._embedding_engine = original @@ -193,14 +193,14 @@ def test_empty_input_returns_400(self, client): import vllm_mlx.server as srv mock_engine = MagicMock() - mock_engine.model_name = "test-embed" + mock_engine.model_name = "mlx-community/all-MiniLM-L6-v2-4bit" original = srv._embedding_engine srv._embedding_engine = mock_engine try: resp = client.post( "/v1/embeddings", - json={"model": "test-embed", "input": []}, + json={"model": "mlx-community/all-MiniLM-L6-v2-4bit", "input": []}, ) finally: srv._embedding_engine = original @@ -208,7 +208,7 @@ def test_empty_input_returns_400(self, client): assert resp.status_code == 400 def test_model_hot_swap(self, client): - """Test that requesting a different model triggers reload.""" + """Test that switching to another allowlisted model triggers reload.""" import vllm_mlx.server as srv mock_engine = MagicMock() @@ -222,17 +222,22 @@ def test_model_hot_swap(self, client): try: with patch("vllm_mlx.embedding.EmbeddingEngine") as mock_cls: new_engine = MagicMock() - new_engine.model_name = "new-model" + new_engine.model_name = "mlx-community/multilingual-e5-small-mlx" new_engine.embed.return_value = [[0.9]] new_engine.count_tokens.return_value = 1 mock_cls.return_value = new_engine resp = client.post( "/v1/embeddings", - json={"model": "new-model", "input": "test"}, + json={ + "model": "mlx-community/multilingual-e5-small-mlx", + "input": "test", + }, ) assert resp.status_code == 200 - mock_cls.assert_called_once_with("new-model") + mock_cls.assert_called_once_with( + "mlx-community/multilingual-e5-small-mlx" + ) new_engine.load.assert_called_once() finally: srv._embedding_engine = original @@ -262,6 +267,18 @@ def test_model_locked_rejects_different_model(self, client): srv._embedding_engine = original_engine srv._embedding_model_locked = original_locked + def test_unknown_embedding_model_rejected(self, client): + """Test that request-time embedding loads reject unknown models.""" + resp = client.post( + "/v1/embeddings", + json={"model": "attacker/unknown-embedding", "input": "test"}, + ) + + assert resp.status_code == 400 + body = resp.json() + assert 
"attacker/unknown-embedding" in body["detail"] + assert "--embedding-model" in body["detail"] + # ============================================================================= # Slow Integration Test - Real Model diff --git a/tests/test_endpoint_model_policies.py b/tests/test_endpoint_model_policies.py new file mode 100644 index 000000000..b6a9e5d89 --- /dev/null +++ b/tests/test_endpoint_model_policies.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Cross-platform tests for optional endpoint model resolution policies.""" + +import pytest +from fastapi import HTTPException + +from vllm_mlx.endpoint_model_policies import ( + resolve_embedding_model_name, + resolve_stt_model_name, + resolve_tts_model_name, +) + + +class TestEmbeddingModelPolicy: + def test_allowlisted_embedding_model_passes(self): + assert ( + resolve_embedding_model_name("mlx-community/multilingual-e5-small-mlx") + == "mlx-community/multilingual-e5-small-mlx" + ) + + def test_unknown_embedding_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_embedding_model_name("attacker/unknown-embedding") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-embedding" in exc_info.value.detail + assert "--embedding-model" in exc_info.value.detail + + def test_locked_embedding_model_can_be_custom(self): + assert ( + resolve_embedding_model_name( + "custom/private-embedding", + locked_model="custom/private-embedding", + ) + == "custom/private-embedding" + ) + + def test_locked_embedding_model_rejects_other_request(self): + with pytest.raises(HTTPException) as exc_info: + resolve_embedding_model_name( + "mlx-community/all-MiniLM-L6-v2-4bit", + locked_model="custom/private-embedding", + ) + + assert exc_info.value.status_code == 400 + assert "custom/private-embedding" in exc_info.value.detail + + +class TestAudioModelPolicy: + def test_stt_alias_resolves_to_configured_model(self): + assert ( + resolve_stt_model_name("whisper-large-v3") + == "mlx-community/whisper-large-v3-mlx" + ) + + def test_stt_full_model_id_is_accepted(self): + model_name = "mlx-community/parakeet-tdt-0.6b-v2" + assert resolve_stt_model_name(model_name) == model_name + + def test_stt_unknown_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_stt_model_name("attacker/unknown-stt") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-stt" in exc_info.value.detail + assert "whisper-large-v3" in exc_info.value.detail + + def test_tts_alias_resolves_to_configured_model(self): + assert resolve_tts_model_name("kokoro") == "mlx-community/Kokoro-82M-bf16" + + def test_tts_full_model_id_is_accepted(self): + model_name = "mlx-community/chatterbox-turbo-fp16" + assert resolve_tts_model_name(model_name) == model_name + + def test_tts_unknown_model_rejected(self): + with pytest.raises(HTTPException) as exc_info: + resolve_tts_model_name("attacker/unknown-tts") + + assert exc_info.value.status_code == 400 + assert "attacker/unknown-tts" in exc_info.value.detail + assert "kokoro" in exc_info.value.detail diff --git a/tests/test_engine_core_stream_safety.py b/tests/test_engine_core_stream_safety.py new file mode 100644 index 000000000..9e8f2f21a --- /dev/null +++ b/tests/test_engine_core_stream_safety.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression guard for issue #407. 
+ +EngineCore previously ran scheduler.step on a separate worker thread, which +caused mx.eval over KV cache state to raise +``RuntimeError: There is no Stream(gpu, N) in current thread`` for Llama 3.x +because mlx-lm's ``generation_stream`` is thread-local. The fix keeps step on +the event-loop thread. This test catches any regression that moves step back +onto a non-owning thread. +""" + +import asyncio +import logging + +import pytest + +# Llama 3.x reliably surfaces the cross-thread stream mismatch. Qwen3 does not. +TEST_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit" + + +@pytest.fixture(scope="module") +def model_and_tokenizer(): + try: + from mlx_lm import load + + return load(TEST_MODEL) + except Exception as e: + pytest.skip(f"Could not load model {TEST_MODEL}: {e}") + + +@pytest.mark.anyio +async def test_engine_core_no_cross_thread_stream_error(model_and_tokenizer, caplog): + """EngineCore must run prefill + decode without raising + ``There is no Stream(gpu, N) in current thread``. + + A regression that moves ``scheduler.step`` off the event-loop thread + (e.g. re-introducing a ThreadPoolExecutor) reintroduces issue #407. + """ + from vllm_mlx import AsyncEngineCore, SamplingParams + + model, tokenizer = model_and_tokenizer + params = SamplingParams(max_tokens=5, temperature=0.0) + + caplog.set_level(logging.ERROR, logger="vllm_mlx.scheduler") + + engine = AsyncEngineCore(model, tokenizer) + await engine.__aenter__() + await asyncio.sleep(0.05) + + rid = await engine.add_request("Hello", params) + tokens = 0 + async for out in engine.stream_outputs(rid, timeout=30): + tokens += 1 + if out.finished: + break + + bg = engine.engine.scheduler.batch_generator + await engine.__aexit__(None, None, None) + + stream_errors = [ + r.message + for r in caplog.records + if "Stream(gpu" in r.message or "no Stream" in r.message + ] + assert ( + not stream_errors + ), f"scheduler logged cross-thread stream errors: {stream_errors}" + assert tokens > 0, "no tokens streamed" + assert bg is not None, ( + "batch generator was None after generation, meaning the scheduler's " + "error-recovery path fired. See issue #407." 
+ ) diff --git a/tests/test_engine_core_thread_streams.py b/tests/test_engine_core_thread_streams.py new file mode 100644 index 000000000..8a06b8276 --- /dev/null +++ b/tests/test_engine_core_thread_streams.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression coverage for MLX stream/thread ownership in engine loops.""" + +import asyncio +import threading +from types import SimpleNamespace + +import pytest + + +class _SchedulerOutput: + outputs = [] + finished_request_ids = [] + + +@pytest.mark.anyio +async def test_engine_core_runs_all_scheduler_steps_on_one_worker_thread(monkeypatch): + """Continuous batching must not bounce MLX steps across threads.""" + from vllm_mlx.engine_core import EngineConfig, EngineCore + + engine = object.__new__(EngineCore) + engine.config = EngineConfig(step_interval=0, stream_interval=1) + engine._running = True + engine._steps_executed = 0 + engine._output_collectors = {} + engine._stream_states = {} + engine._finished_events = {} + + main_thread = threading.get_ident() + bind_threads: list[int] = [] + + class FakeScheduler: + batch_generator = SimpleNamespace(_partial=None) + + def __init__(self): + self.calls = 0 + self.step_threads: list[int] = [] + self.close_threads: list[int] = [] + + def has_requests(self): + return self.calls < 3 + + def step(self): + self.step_threads.append(threading.get_ident()) + self.calls += 1 + if self.calls == 3: + engine._running = False + return _SchedulerOutput() + + def _close_batch_generator(self): + self.close_threads.append(threading.get_ident()) + + scheduler = FakeScheduler() + engine.scheduler = scheduler + + def bind_streams(): + bind_threads.append(threading.get_ident()) + + monkeypatch.setattr("vllm_mlx.engine_core.bind_generation_streams", bind_streams) + + await asyncio.wait_for(engine._engine_loop(), timeout=2) + + assert scheduler.step_threads + assert len(set(scheduler.step_threads)) == 1 + assert scheduler.step_threads[0] != main_thread + assert bind_threads == [scheduler.step_threads[0]] + assert scheduler.close_threads == [scheduler.step_threads[0]] + + +@pytest.mark.anyio +async def test_mllm_scheduler_runs_steps_on_model_load_thread(monkeypatch): + """MLLM keeps generation on the event-loop thread that loaded the model.""" + from vllm_mlx.mllm_scheduler import MLLMScheduler + + scheduler = object.__new__(MLLMScheduler) + scheduler._running = True + + main_thread = threading.get_ident() + bind_threads: list[int] = [] + step_threads: list[int] = [] + close_threads: list[int] = [] + + class FakeBatchGenerator: + _partial = None + + def close(self): + close_threads.append(threading.get_ident()) + + scheduler.batch_generator = FakeBatchGenerator() + + def has_requests(): + return len(step_threads) < 3 + + def step(): + step_threads.append(threading.get_ident()) + if len(step_threads) == 3: + scheduler._running = False + + def bind_streams(): + bind_threads.append(threading.get_ident()) + + scheduler.has_requests = has_requests + scheduler.step = step + monkeypatch.setattr("vllm_mlx.mllm_scheduler.bind_generation_streams", bind_streams) + + await asyncio.wait_for(scheduler._process_loop(), timeout=2) + + assert step_threads + assert len(set(step_threads)) == 1 + assert step_threads[0] == main_thread + assert bind_threads == [main_thread] + assert close_threads == [] diff --git a/tests/test_gemma4_streaming_edge.py b/tests/test_gemma4_streaming_edge.py new file mode 100644 index 000000000..4aed1d35d --- /dev/null +++ b/tests/test_gemma4_streaming_edge.py @@ -0,0 +1,213 @@ +#!/usr/bin/env 
python3
+"""Unit tests for Gemma4 streaming parser edge cases."""
+
+from vllm_mlx.reasoning.gemma4_parser import Gemma4ReasoningParser
+
+
+def stream_parse(parser, text, token_list):
+    """Feed token_list to parser one-by-one, accumulating reasoning/content."""
+    parser.reset_state()
+    accumulated = ""
+    reasoning_parts = []
+    content_parts = []
+    for tok in token_list:
+        prev = accumulated
+        accumulated += tok
+        msg = parser.extract_reasoning_streaming(prev, accumulated, tok)
+        if msg is None:
+            continue
+        if msg.reasoning:
+            reasoning_parts.append(msg.reasoning)
+        if msg.content:
+            content_parts.append(msg.content)
+    return "".join(reasoning_parts), "".join(content_parts)
+
+
+def test_thought_label_split_across_deltas():
+    """The channel-name word 'thought' can arrive split across deltas in
+    production streams. The parser must buffer the partial prefix rather than
+    leak 'th' / 'tho' / 'thou' / 'thoug' into the reasoning output.
+    """
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "tho",  # partial "thought" label
+        "ught",  # completes the label
+        "\n",
+        "real reasoning",
+        "",
+        "content",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert reasoning == "real reasoning", f"Label leaked: {reasoning!r}"
+    assert content == "content", f"Got content: {content!r}"
+    print("[PASS] test_thought_label_split_across_deltas")
+
+
+def test_channel_marker_split_across_tokens():
+    """<|channel> alone should not leak into reasoning if response follows."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "reasoning text",
+        "<|channel>",
+        "response",
+        "\n",
+        "content text",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert "<|channel>" not in reasoning, f"Marker leaked into reasoning: {reasoning!r}"
+    assert "<|channel>" not in content, f"Marker leaked into content: {content!r}"
+    assert reasoning == "reasoning text", f"Got: {reasoning!r}"
+    assert content == "content text", f"Got: {content!r}"
+    print("[PASS] test_channel_marker_split_across_tokens")
+
+
+def test_leading_newline_after_transition():
+    """Leading \\n after <|channel>response should be stripped from content."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "reasoning",
+        "<|channel>response",
+        "\n",  # marker in one token, \n in next
+        "actual content",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert content == "actual content", f"Got: {content!r}"
+    print("[PASS] test_leading_newline_after_transition")
+
+
+def test_realistic_deterministic_production_stream():
+    """Reproduce realistic production output with many small deltas."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "The",
+        " user",
+        " said",
+        ' "',
+        "hi",
+        '".',
+        " Plan",
+        ":",
+        " greet",
+        " politely",
+        ".",
+        "\n",
+        "<|channel>",
+        "response",
+        "\n",
+        "Hello",
+        "!",
+        " How",
+        " can",
+        " I",
+        " help",
+        " you",
+        " today",
+        "?",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert "<|channel>" not in reasoning, f"Marker leaked: {reasoning!r}"
+    assert "<|channel>" not in content, f"Marker leaked: {content!r}"
+    assert (
+        reasoning == 'The user said "hi". Plan: greet politely.\n'
+    ), f"Got: {reasoning!r}"
+    assert content == "Hello! How can I help you today?", f"Got: {content!r}"
+    print("[PASS] test_realistic_deterministic_production_stream")
+
+
+def test_standard_format_with_split():
+    """Standard <|channel>thought...content format, split token."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "reasoning",
+        "",
+        "content",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert reasoning == "reasoning", f"Got: {reasoning!r}"
+    assert content == "content", f"Got: {content!r}"
+    print("[PASS] test_standard_format_with_split")
+
+
+def test_no_transition_stays_in_thinking():
+    """When model never emits transition marker, all goes to reasoning."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "reasoning with",
+        " no",
+        " transition",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    assert reasoning == "reasoning with no transition", f"Got: {reasoning!r}"
+    assert content == "", f"Got: {content!r}"
+    print("[PASS] test_no_transition_stays_in_thinking")
+
+
+def test_finalize_stream_fallback():
+    """If stream ends in thinking phase, finalize should allow fallback."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "draft answer here",
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    # Should still classify as reasoning (current correct behavior)
+    assert reasoning == "draft answer here", f"Got: {reasoning!r}"
+    # Finalize hook should return any pending partial marker buffer
+    msg = parser.finalize_stream()
+    # If parser buffered a partial <|channel> at end of stream, it should flush here
+    assert msg is None or msg.reasoning is not None or msg.content is not None
+    print("[PASS] test_finalize_stream_fallback")
+
+
+def test_partial_marker_at_end_flushed():
+    """If stream ends with buffered partial marker, finalize emits it."""
+    parser = Gemma4ReasoningParser()
+    tokens = [
+        "<|channel>",
+        "thought",
+        "\n",
+        "text",
+        "<|channel>",  # partial — could be transition or just end
+    ]
+    reasoning, content = stream_parse(parser, None, tokens)
+    # At this point, "<|channel>" should be buffered, not emitted
+    assert "<|channel>" not in reasoning, f"Leaked: {reasoning!r}"
+    # Finalize flushes the buffered marker as reasoning (it wasn't transition)
+    msg = parser.finalize_stream()
+    if msg and msg.reasoning:
+        reasoning += msg.reasoning
+    assert reasoning == "text<|channel>", f"Got: {reasoning!r}"
+    print("[PASS] test_partial_marker_at_end_flushed")
+
+
+if __name__ == "__main__":
+    test_thought_label_split_across_deltas()
+    test_channel_marker_split_across_tokens()
+    test_leading_newline_after_transition()
+    test_realistic_deterministic_production_stream()
+    test_standard_format_with_split()
+    test_no_transition_stays_in_thinking()
+    test_finalize_stream_fallback()
+    test_partial_marker_at_end_flushed()
+    print("\nAll tests passed!")
diff --git a/tests/test_gemma4_tool_parser.py b/tests/test_gemma4_tool_parser.py
index 179915442..3ec668f85 100644
--- a/tests/test_gemma4_tool_parser.py
+++ b/tests/test_gemma4_tool_parser.py
@@ -167,6 +167,43 @@ def test_string_with_newline_and_quote(self):
         args = json.loads(result.tool_calls[0]["arguments"])
         assert args == {"text": 'line1\nline2 said "hello"'}
 
+    def test_bare_string_value_without_delimiters(self):
+        """Nullable type (e.g. ["string", "null"]) makes the template skip the
+        <|"|> wrap around string values. The parser must still produce valid
+        JSON with the value as a string.
+        Reference: llama.cpp PR #21327.
+ """ + output = "<|tool_call>call:set_state{domain:light}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"domain": "light"} + + def test_bare_string_mixed_with_number_and_bool(self): + """Bare string value must not interfere with numeric/bool parsing.""" + output = "<|tool_call>call:update{name:alice,count:5,active:true}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"name": "alice", "count": 5, "active": True} + + def test_bare_string_preserves_null_and_bool_literals(self): + """null/true/false must NOT be treated as bare strings.""" + output = ( + "<|tool_call>call:cfg{flag:null,ready:true,done:false,name:bob}" + ) + result = self.parser.extract_tool_calls(output) + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"flag": None, "ready": True, "done": False, "name": "bob"} + + def test_bare_string_in_array(self): + """Enum-without-type: array of bare strings should be quoted per element.""" + output = "<|tool_call>call:filter{tags:[alpha,beta,gamma]}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"tags": ["alpha", "beta", "gamma"]} + class TestGemma4ToolParserStreaming: """Test streaming tool call extraction.""" @@ -237,4 +274,47 @@ def test_registered_in_manager(self): def test_native_format_false(self): assert Gemma4ToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is False + + def test_extra_stop_tokens_declares_tool_response(self): + """Gemma 4 treats <|tool_response> (id 50) as end-of-generation + after a tool call. The parser exposes it so the server can merge it + into the request's stop sequences. + Reference: llama.cpp PR #21418. 
+ """ + parser = Gemma4ToolParser() + assert "<|tool_response>" in parser.extra_stop_tokens + + def test_abstract_parser_default_empty_stop_tokens(self): + """Other parsers that don't override keep an empty default.""" + from vllm_mlx.tool_parsers.abstract_tool_parser import ToolParser + + assert ToolParser.extra_stop_tokens == [] + + def test_merge_helper_adds_parser_extras(self): + """get_parser_stop_tokens adds the parser's EOG tokens on top of user stops.""" + from vllm_mlx.tool_parsers import get_parser_stop_tokens + + merged = get_parser_stop_tokens("gemma4", ["END"]) + assert "END" in merged + assert "<|tool_response>" in merged + + def test_merge_helper_dedupes(self): + """Parser extra tokens already present in user stops aren't duplicated.""" + from vllm_mlx.tool_parsers import get_parser_stop_tokens + + merged = get_parser_stop_tokens("gemma4", ["<|tool_response>"]) + assert merged.count("<|tool_response>") == 1 + + def test_merge_helper_unknown_parser_is_passthrough(self): + """Unknown parser name leaves user stops untouched.""" + from vllm_mlx.tool_parsers import get_parser_stop_tokens + + assert get_parser_stop_tokens("nonexistent_parser_xyz", ["A"]) == ["A"] + + def test_merge_helper_none_parser_is_passthrough(self): + """No parser name returns user stops as-is.""" + from vllm_mlx.tool_parsers import get_parser_stop_tokens + + assert get_parser_stop_tokens(None, ["A"]) == ["A"] + assert get_parser_stop_tokens(None, None) == [] assert Gemma4ToolParser.supports_native_format() is False diff --git a/tests/test_kv_cache_quantization.py b/tests/test_kv_cache_quantization.py index 05d36670c..f8b770aa1 100644 --- a/tests/test_kv_cache_quantization.py +++ b/tests/test_kv_cache_quantization.py @@ -2,11 +2,12 @@ """Tests for KV cache quantization in prefix cache.""" import mlx.core as mx -from mlx_lm.models.cache import KVCache, QuantizedKVCache +from mlx_lm.models.cache import KVCache from vllm_mlx.memory_cache import ( MemoryAwarePrefixCache, MemoryCacheConfig, + _QuantizedCacheWrapper, _dequantize_cache, _quantize_cache, _trim_to_offset, @@ -37,7 +38,7 @@ def test_quantize_produces_quantized_cache(self): quantized = _quantize_cache(cache, bits=8, group_size=64) assert len(quantized) == len(cache) for layer in quantized: - assert isinstance(layer, QuantizedKVCache) + assert isinstance(layer, _QuantizedCacheWrapper) def test_dequantize_produces_kv_cache(self): cache = _make_kv_cache() @@ -107,7 +108,7 @@ def test_non_kvcache_layers_preserved(self): cache = [kv, fake_mamba] quantized = _quantize_cache(cache, bits=8, group_size=64) - assert isinstance(quantized[0], QuantizedKVCache) + assert isinstance(quantized[0], _QuantizedCacheWrapper) assert isinstance(quantized[1], dict) # Preserved as-is restored = _dequantize_cache(quantized) @@ -196,7 +197,7 @@ def test_store_fetch_with_quantization(self): # Internally stored as quantized stored_entry = list(pc._entries.values())[0] for layer in stored_entry.cache: - assert isinstance(layer, QuantizedKVCache) + assert isinstance(layer, _QuantizedCacheWrapper) # Fetched as dequantized KVCache fetched, remaining = pc.fetch(tokens) @@ -312,7 +313,7 @@ def test_store_quantizes_above_threshold(self): stored_entry = list(pc._entries.values())[0] for layer in stored_entry.cache: assert isinstance( - layer, QuantizedKVCache + layer, _QuantizedCacheWrapper ), "Long sequences should be quantized" def test_trim_applied_without_quantization(self): diff --git a/tests/test_lifecycle_cli.py b/tests/test_lifecycle_cli.py new file mode 100644 index 
000000000..e2a8ba4a8 --- /dev/null +++ b/tests/test_lifecycle_cli.py @@ -0,0 +1,424 @@ +# SPDX-License-Identifier: Apache-2.0 +"""CLI forwarding tests for lifecycle configuration flags.""" + +from __future__ import annotations + +import sys +from types import SimpleNamespace + +import pytest + + +@pytest.fixture(autouse=True) +def restore_server_globals(): + """Restore mutated server module globals between lifecycle tests.""" + import vllm_mlx.server as srv + + sentinel = object() + global_names = ( + "_engine", + "_model_name", + "_model_path", + "_default_model_key", + "_default_max_tokens", + "_default_timeout", + "_default_temperature", + "_default_top_p", + "_force_mllm_model", + "_auto_unload_idle_seconds", + "_lazy_load_model", + "_residency_manager", + "_lifecycle_task", + "_lifespan_active", + "_mcp_manager", + "_mcp_executor", + "_embedding_engine", + "_embedding_model_locked", + "_api_key", + "_auth_warning_logged", + "_rate_limiter", + "_reasoning_parser", + "_enable_auto_tool_choice", + "_tool_call_parser", + "_tool_parser_instance", + "_idle_unload_enabled", + ) + snapshot = {name: getattr(srv, name, sentinel) for name in global_names} + + yield + + leaked_task = getattr(srv, "_lifecycle_task", None) + original_task = snapshot["_lifecycle_task"] + if ( + leaked_task is not sentinel + and leaked_task is not None + and leaked_task is not original_task + and not leaked_task.done() + ): + leaked_task.cancel() + + for name, value in snapshot.items(): + if value is sentinel: + if hasattr(srv, name): + delattr(srv, name) + else: + setattr(srv, name, value) + + # _idle_unload_enabled is a lazily-created asyncio.Event bound to whatever + # event loop was running when _get_idle_unload_event() was first called. + # Reset to None so the next test gets a fresh Event on its own loop. 
+ srv._idle_unload_enabled = None + + +class TestLifecycleCli: + """Lock in the first lifecycle CLI surface area.""" + + def test_main_parses_auto_unload_idle_seconds_flag(self, monkeypatch): + """The top-level CLI should accept an idle-unload timeout knob.""" + import vllm_mlx.cli as cli + + captured = {} + + def fake_serve_command(args): + captured["auto_unload_idle_seconds"] = args.auto_unload_idle_seconds + + monkeypatch.setattr(cli, "serve_command", fake_serve_command) + monkeypatch.setattr( + cli.sys, + "argv", + [ + "vllm-mlx", + "serve", + "mlx-community/Qwen3-0.6B-8bit", + "--auto-unload-idle-seconds", + "300", + ], + ) + + cli.main() + + assert captured["auto_unload_idle_seconds"] == 300 + + def test_main_parses_lazy_load_model_flag(self, monkeypatch): + """The top-level CLI should accept lazy lifecycle startup.""" + import vllm_mlx.cli as cli + + captured = {} + + def fake_serve_command(args): + captured["lazy_load_model"] = args.lazy_load_model + + monkeypatch.setattr(cli, "serve_command", fake_serve_command) + monkeypatch.setattr( + cli.sys, + "argv", + [ + "vllm-mlx", + "serve", + "mlx-community/Qwen3-0.6B-8bit", + "--lazy-load-model", + ], + ) + + cli.main() + + assert captured["lazy_load_model"] is True + + def test_main_defaults_lazy_load_model_to_false(self, monkeypatch): + """Serve startup should stay eager unless the user explicitly opts in.""" + import vllm_mlx.cli as cli + + captured = {} + + def fake_serve_command(args): + captured["lazy_load_model"] = args.lazy_load_model + + monkeypatch.setattr(cli, "serve_command", fake_serve_command) + monkeypatch.setattr( + cli.sys, + "argv", + [ + "vllm-mlx", + "serve", + "mlx-community/Qwen3-0.6B-8bit", + ], + ) + + cli.main() + + assert captured["lazy_load_model"] is False + + def test_serve_command_wires_auto_unload_idle_seconds_into_load_model( + self, monkeypatch + ): + """serve_command should pass lifecycle config into model loading.""" + import uvicorn + + import vllm_mlx.cli as cli + import vllm_mlx.server as srv + + captured = {} + + def fake_load_model(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + monkeypatch.setattr(srv, "load_model", fake_load_model) + monkeypatch.setattr(uvicorn, "run", lambda *args, **kwargs: None) + + args = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + host="127.0.0.1", + port=8000, + max_num_seqs=256, + prefill_batch_size=8, + completion_batch_size=32, + enable_prefix_cache=True, + disable_prefix_cache=False, + prefix_cache_size=100, + cache_memory_mb=None, + cache_memory_percent=0.20, + no_memory_aware_cache=False, + kv_cache_quantization=False, + kv_cache_quantization_bits=8, + kv_cache_quantization_group_size=64, + kv_cache_min_quantize_tokens=256, + stream_interval=7, + max_tokens=32768, + continuous_batching=False, + use_paged_cache=False, + paged_cache_block_size=64, + max_cache_blocks=1000, + chunked_prefill_tokens=0, + enable_mtp=False, + mtp_num_draft_tokens=1, + mtp_optimistic=False, + prefill_step_size=2048, + specprefill=False, + specprefill_threshold=8192, + specprefill_keep_pct=0.3, + specprefill_draft_model=None, + mcp_config=None, + api_key=None, + rate_limit=0, + timeout=300.0, + enable_auto_tool_choice=False, + tool_call_parser=None, + reasoning_parser=None, + mllm=False, + default_temperature=None, + default_top_p=None, + default_chat_template_kwargs=None, + served_model_name=None, + embedding_model=None, + gpu_memory_utilization=0.90, + enable_metrics=False, + download_timeout=120, + download_retries=3, + 
mllm_prefill_step_size=None, + lazy_load_model=True, + auto_unload_idle_seconds=300, + ) + + cli.serve_command(args) + + assert captured["kwargs"]["stream_interval"] == 1 + assert captured["kwargs"]["auto_unload_idle_seconds"] == 300 + assert captured["kwargs"]["lazy_load_model"] is True + + def test_serve_command_preserves_mtp_scheduler_config_with_residency( + self, monkeypatch + ): + """serve_command should keep batching/MTP config intact alongside residency.""" + import uvicorn + + import vllm_mlx.cli as cli + import vllm_mlx.server as srv + + captured = {} + + def fake_load_model(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + class FakeSchedulerConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + monkeypatch.setattr(srv, "load_model", fake_load_model) + monkeypatch.setattr(uvicorn, "run", lambda *args, **kwargs: None) + monkeypatch.setitem( + sys.modules, + "vllm_mlx.scheduler", + SimpleNamespace(SchedulerConfig=FakeSchedulerConfig), + ) + + args = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + host="127.0.0.1", + port=8000, + max_num_seqs=256, + prefill_batch_size=8, + completion_batch_size=32, + enable_prefix_cache=True, + disable_prefix_cache=False, + prefix_cache_size=100, + cache_memory_mb=None, + cache_memory_percent=0.20, + no_memory_aware_cache=False, + kv_cache_quantization=False, + kv_cache_quantization_bits=8, + kv_cache_quantization_group_size=64, + kv_cache_min_quantize_tokens=256, + stream_interval=7, + max_tokens=32768, + continuous_batching=True, + use_paged_cache=False, + paged_cache_block_size=64, + max_cache_blocks=1000, + chunked_prefill_tokens=0, + enable_mtp=True, + mtp_num_draft_tokens=4, + mtp_optimistic=True, + prefill_step_size=2048, + specprefill=False, + specprefill_threshold=8192, + specprefill_keep_pct=0.3, + specprefill_draft_model=None, + mcp_config=None, + api_key=None, + rate_limit=0, + timeout=300.0, + enable_auto_tool_choice=False, + tool_call_parser=None, + reasoning_parser=None, + mllm=False, + default_temperature=None, + default_top_p=None, + default_chat_template_kwargs=None, + served_model_name=None, + embedding_model=None, + gpu_memory_utilization=0.90, + enable_metrics=False, + download_timeout=120, + download_retries=3, + mllm_prefill_step_size=0, + lazy_load_model=True, + auto_unload_idle_seconds=300, + ) + + cli.serve_command(args) + + scheduler_config = captured["kwargs"]["scheduler_config"] + assert captured["kwargs"]["use_batching"] is True + assert captured["kwargs"]["stream_interval"] == 7 + assert captured["kwargs"]["auto_unload_idle_seconds"] == 300 + assert captured["kwargs"]["lazy_load_model"] is True + assert scheduler_config.enable_mtp is True + assert scheduler_config.mtp_num_draft_tokens == 4 + assert scheduler_config.mtp_optimistic is True + + def test_server_main_preserves_use_batching_with_residency_flags(self, monkeypatch): + """server.main should forward use_batching and residency knobs together.""" + import vllm_mlx.server as srv + + captured = {} + + def fake_load_model(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + + monkeypatch.setattr(srv, "load_model", fake_load_model) + monkeypatch.setattr(srv, "load_embedding_model", lambda *args, **kwargs: None) + monkeypatch.setattr(srv.uvicorn, "run", lambda *args, **kwargs: None) + monkeypatch.setattr( + sys, + "argv", + [ + "vllm_mlx.server", + "--model", + "mlx-community/Qwen3-0.6B-8bit", + "--continuous-batching", + "--auto-unload-idle-seconds", + "300", + "--lazy-load-model", + 
], + ) + + srv.main() + + assert captured["kwargs"]["use_batching"] is True + assert captured["kwargs"]["auto_unload_idle_seconds"] == 300.0 + assert captured["kwargs"]["lazy_load_model"] is True + + def test_serve_command_describes_lazy_startup_without_claiming_model_is_loaded( + self, monkeypatch, capsys + ): + """Lazy startup output should mention the first request, not claim eager load.""" + import uvicorn + + import vllm_mlx.cli as cli + import vllm_mlx.server as srv + + monkeypatch.setattr(srv, "load_model", lambda *args, **kwargs: None) + monkeypatch.setattr(uvicorn, "run", lambda *args, **kwargs: None) + + args = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + host="127.0.0.1", + port=8000, + max_num_seqs=256, + prefill_batch_size=8, + completion_batch_size=32, + enable_prefix_cache=True, + disable_prefix_cache=False, + prefix_cache_size=100, + cache_memory_mb=None, + cache_memory_percent=0.20, + no_memory_aware_cache=False, + kv_cache_quantization=False, + kv_cache_quantization_bits=8, + kv_cache_quantization_group_size=64, + kv_cache_min_quantize_tokens=256, + stream_interval=1, + max_tokens=32768, + continuous_batching=False, + use_paged_cache=False, + paged_cache_block_size=64, + max_cache_blocks=1000, + chunked_prefill_tokens=0, + enable_mtp=False, + mtp_num_draft_tokens=1, + mtp_optimistic=False, + prefill_step_size=2048, + specprefill=False, + specprefill_threshold=8192, + specprefill_keep_pct=0.3, + specprefill_draft_model=None, + mcp_config=None, + api_key=None, + rate_limit=0, + timeout=300.0, + enable_auto_tool_choice=False, + tool_call_parser=None, + reasoning_parser=None, + mllm=False, + default_temperature=None, + default_top_p=None, + default_chat_template_kwargs=None, + served_model_name=None, + embedding_model=None, + gpu_memory_utilization=0.90, + enable_metrics=False, + download_timeout=120, + download_retries=3, + mllm_prefill_step_size=None, + lazy_load_model=True, + auto_unload_idle_seconds=0.0, + ) + + cli.serve_command(args) + out = capsys.readouterr().out + + assert "Loading model: mlx-community/Qwen3-0.6B-8bit" not in out + assert "first request" in out diff --git a/tests/test_lifecycle_manager.py b/tests/test_lifecycle_manager.py new file mode 100644 index 000000000..c31c71c1d --- /dev/null +++ b/tests/test_lifecycle_manager.py @@ -0,0 +1,1413 @@ +# SPDX-License-Identifier: Apache-2.0 +"""ResidencyManager unit tests — lifecycle state machine, cancellation, and shutdown.""" + +from __future__ import annotations + +import asyncio +import threading +import time +from contextlib import suppress + +import pytest + + +async def _wait_for_resident_state(manager, model_key: str, state: str) -> None: + """Poll until a resident reaches the expected state.""" + while manager.get_status(model_key)["state"] != state: + await asyncio.sleep(0) + + +class TestResidencyManagerContracts: + """Lock in the high-risk lifecycle invariants at the manager layer.""" + + @pytest.mark.anyio + async def test_concurrent_acquire_single_flights_initial_load(self): + """Concurrent acquires for one model should perform only one load.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + create_calls = 0 + started = 0 + + class FakeEngine: + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + return None + + async def engine_factory(spec): + nonlocal create_calls + create_calls += 1 + await asyncio.sleep(0) + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + 
auto_unload_idle_seconds=300, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + engines = await asyncio.gather( + manager.acquire("default"), + manager.acquire("default"), + manager.acquire("default"), + ) + + assert create_calls == 1 + assert started == 1 + assert len({id(engine) for engine in engines}) == 1 + + for _ in engines: + await manager.release("default") + + @pytest.mark.anyio + async def test_unload_if_idle_respects_active_requests(self): + """Idle unload should be blocked while a resident still has users.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + stopped = 0 + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + nonlocal stopped + stopped += 1 + + now = {"value": 1000.0} + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + await manager.acquire("default") + now["value"] += 120.0 + + unloaded = await manager.unload_if_idle("default") + + assert unloaded is False + assert stopped == 0 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + + @pytest.mark.anyio + async def test_release_updates_last_used_at(self): + """release() should refresh the timestamp used by idle-unload policy.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + now = {"value": 1000.0} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + await manager.acquire("default") + before_release = manager.get_status("default")["last_used_at"] + + now["value"] = 1125.0 + await manager.release("default") + + after_release = manager.get_status("default")["last_used_at"] + assert before_release != after_release + assert after_release == 1125.0 + + @pytest.mark.anyio + async def test_unload_after_idle_threshold_and_reload(self): + """A released idle resident should unload and later reload cleanly.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + stopped = 0 + started = 0 + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + async def start(self): + nonlocal started + started += 1 + return None + + async def stop(self): + nonlocal stopped + stopped += 1 + + now = {"value": 1000.0} + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = await manager.acquire("default") + await manager.release("default") + + now["value"] += 120.0 + unloaded = await manager.unload_if_idle("default") + + assert unloaded is True + assert stopped == 1 + assert manager.get_status("default")["state"] == "unloaded" + + second = await manager.acquire("default") + + assert created == 2 + assert started == 2 + assert first is not second + assert manager.get_status("default")["state"] == "loaded" + + await 
manager.release("default") + + @pytest.mark.anyio + async def test_cancelled_cold_load_does_not_wedge_future_acquires(self): + """Cancelling one waiter should not poison the shared resident load.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + load_gate = asyncio.Event() + + class FakeEngine: + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + await load_gate.wait() + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + await asyncio.sleep(0) + + first.cancel() + with pytest.raises(asyncio.CancelledError): + await first + + load_gate.set() + second = await manager.acquire("default") + + assert second is not None + assert created == 2 + assert started == 1 + assert stopped == 0 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + + @pytest.mark.anyio + async def test_last_cancelled_waiter_cancels_cold_load(self): + """Cancelling the final waiter should unwind the in-flight resident load.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + load_gate = asyncio.Event() + start_entered = asyncio.Event() + start_cancelled = asyncio.Event() + + class FakeEngine: + async def start(self): + nonlocal started + started += 1 + start_entered.set() + try: + await load_gate.wait() + except asyncio.CancelledError: + start_cancelled.set() + raise + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + try: + await asyncio.wait_for(start_entered.wait(), timeout=1.0) + + first.cancel() + with pytest.raises(asyncio.CancelledError): + await first + + await asyncio.wait_for(start_cancelled.wait(), timeout=1.0) + + status = manager.get_status("default") + assert status["state"] == "unloaded" + assert manager._residents["default"]._loading_task is None + assert stopped == 1 + + load_gate.set() + second = await asyncio.wait_for(manager.acquire("default"), timeout=1.0) + + assert second is not None + assert created == 2 + assert started == 2 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + finally: + load_gate.set() + if not first.done(): + first.cancel() + with suppress(asyncio.CancelledError): + await first + + @pytest.mark.anyio + async def test_cancelled_waiter_does_not_cancel_shared_cold_load(self): + """A canceled waiter should not kill a cold load another waiter still needs.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + load_gate = asyncio.Event() + start_entered = asyncio.Event() + start_cancelled = asyncio.Event() + + class FakeEngine: + async def start(self): + nonlocal started + started += 1 + start_entered.set() + try: + await load_gate.wait() + except asyncio.CancelledError: + start_cancelled.set() + 
raise + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + second = asyncio.create_task(manager.acquire("default")) + try: + await asyncio.wait_for(start_entered.wait(), timeout=1.0) + + first.cancel() + with pytest.raises(asyncio.CancelledError): + await first + + await asyncio.sleep(0) + assert not start_cancelled.is_set() + + load_gate.set() + engine = await asyncio.wait_for(second, timeout=1.0) + + assert engine is not None + assert created == 1 + assert started == 1 + assert stopped == 0 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + finally: + load_gate.set() + for task in (first, second): + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task + + @pytest.mark.anyio + async def test_cancelled_last_waiter_suppresses_load_failure_from_cancel(self): + """Abandoned cold loads should still surface as cancellation to the waiter.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + start_entered = asyncio.Event() + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + async def start(self): + nonlocal started + started += 1 + if self.generation != 1: + return None + start_entered.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError as exc: + raise RuntimeError("boom-from-start") from exc + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(start_entered.wait(), timeout=1.0) + first.cancel() + await first + + status = manager.get_status("default") + assert status["state"] == "unloaded" + assert status["last_error"] is None + assert manager._residents["default"]._loading_task is None + assert stopped == 1 + + second = await asyncio.wait_for(manager.acquire("default"), timeout=1.0) + + assert second is not None + assert second.generation == 2 + assert created == 2 + assert started == 2 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + + @pytest.mark.anyio + async def test_cancelled_prepare_for_start_does_not_finish_after_stop(self): + """Cancelled cold loads should not let prepare_for_start keep mutating after stop.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + prepare_entered = threading.Event() + allow_prepare_finish = threading.Event() + prepare_finished = threading.Event() + prepare_finished_after_stop = threading.Event() + stop_called = threading.Event() + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + def prepare_for_start(self): + prepare_entered.set() + allow_prepare_finish.wait() + time.sleep(0.01) + if stop_called.is_set(): + 
prepare_finished_after_stop.set() + prepare_finished.set() + + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + nonlocal stopped + stopped += 1 + stop_called.set() + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + try: + assert await asyncio.wait_for( + asyncio.to_thread(prepare_entered.wait, 1.0), + timeout=1.0, + ) + + first.cancel() + allow_prepare_finish.set() + + with pytest.raises(asyncio.CancelledError): + await first + + assert await asyncio.wait_for( + asyncio.to_thread(prepare_finished.wait, 1.0), + timeout=1.0, + ) + assert not prepare_finished_after_stop.is_set() + + await asyncio.wait_for( + _wait_for_resident_state(manager, "default", "unloaded"), + timeout=1.0, + ) + + status = manager.get_status("default") + assert status["state"] == "unloaded" + assert manager._residents["default"]._loading_task is None + + second = await asyncio.wait_for(manager.acquire("default"), timeout=1.0) + + assert second is not None + assert second.generation == 2 + assert created == 2 + assert started == 1 + assert stopped == 1 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + finally: + allow_prepare_finish.set() + if not first.done(): + first.cancel() + with suppress(asyncio.CancelledError): + await first + + @pytest.mark.anyio + async def test_late_joining_waiter_retries_abandoned_cold_load(self): + """A waiter joining during abandoned-load cleanup should retry instead of inheriting cancel.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + start_entered = asyncio.Event() + stop_entered = asyncio.Event() + allow_stop = asyncio.Event() + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + async def start(self): + nonlocal started + started += 1 + if self.generation != 1: + return None + start_entered.set() + await asyncio.Event().wait() + + async def stop(self): + nonlocal stopped + stopped += 1 + stop_entered.set() + await allow_stop.wait() + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + try: + await asyncio.wait_for(start_entered.wait(), timeout=1.0) + + first.cancel() + + await asyncio.wait_for(stop_entered.wait(), timeout=1.0) + + second = asyncio.create_task(manager.acquire("default")) + try: + await asyncio.sleep(0) + + allow_stop.set() + with pytest.raises(asyncio.CancelledError): + await first + engine = await asyncio.wait_for(second, timeout=1.0) + + assert engine is not None + assert engine.generation == 2 + assert created == 2 + assert started == 2 + assert stopped == 1 + assert manager.get_status("default")["state"] == "loaded" + finally: + if not second.done(): + second.cancel() + with suppress(asyncio.CancelledError): + await second + + await manager.release("default") + finally: + allow_stop.set() + if not first.done(): + first.cancel() + with 
suppress(asyncio.CancelledError): + await first + + @pytest.mark.anyio + async def test_cancelled_shared_load_during_state_commit_recovers_cleanly(self): + """Cancelling the shared load during state commit should stop the engine and recover.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + started = 0 + stopped = 0 + + class CommitGateLock: + def __init__(self): + self._real = asyncio.Lock() + self.block_next_enter = False + self.commit_waiting = asyncio.Event() + self.allow_commit = asyncio.Event() + + async def __aenter__(self): + if self.block_next_enter: + self.block_next_enter = False + self.commit_waiting.set() + await self.allow_commit.wait() + await self._real.acquire() + return self + + async def __aexit__(self, exc_type, exc, tb): + self._real.release() + + lock = CommitGateLock() + + class FakeEngine: + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine() + + async def on_engine_loaded(spec, engine): + lock.block_next_enter = True + + manager = ResidencyManager( + engine_factory=engine_factory, + on_engine_loaded=on_engine_loaded, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager._lock = lock + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + first = asyncio.create_task(manager.acquire("default")) + await lock.commit_waiting.wait() + + loading_task = manager._residents["default"]._loading_task + assert loading_task is not None + loading_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await first + + assert stopped == 1 + assert manager.get_status("default")["state"] == "unloaded" + + lock.allow_commit.set() + second = await asyncio.wait_for(manager.acquire("default"), timeout=1.0) + + assert created == 2 + assert started == 2 + assert manager.get_status("default")["state"] == "loaded" + + await manager.release("default") + + @pytest.mark.anyio + async def test_cancelled_load_cleanup_handles_legacy_tasks_without_uncancel_api( + self, monkeypatch + ): + """Cancellation cleanup should still work on supported runtimes without Task.uncancel().""" + import vllm_mlx.lifecycle as lifecycle_mod + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager, ResidentState + + stopped = 0 + + class LegacyTask: + def cancel(self): + return None + + class FakeEngine: + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + raise AssertionError("engine_factory should not be called") + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + resident = manager._residents["default"] + resident.state = ResidentState.LOADING + resident._loading_task = object() + + monkeypatch.setattr( + lifecycle_mod.asyncio, + "current_task", + lambda: LegacyTask(), + ) + + await manager._cleanup_cancelled_load(resident, FakeEngine()) + + assert stopped == 1 + assert resident.state == ResidentState.UNLOADED + assert resident._loading_task is None + + @pytest.mark.anyio + async def test_shutdown_cancels_inflight_cold_load(self): + """shutdown() should not allow a cold load to complete afterward.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + started = 0 + load_gate = asyncio.Event() + + class FakeEngine: + async def start(self): + nonlocal started + 
started += 1 + + async def stop(self): + return None + + async def engine_factory(spec): + await load_gate.wait() + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + acquire_task = asyncio.create_task(manager.acquire("default")) + await asyncio.sleep(0) + + shutdown_task = asyncio.create_task(manager.shutdown()) + await asyncio.sleep(0) + + load_gate.set() + await shutdown_task + + with pytest.raises(asyncio.CancelledError): + await acquire_task + + assert started == 0 + assert manager.get_engine("default") is None + assert manager.get_status("default")["state"] == "unloaded" + + @pytest.mark.anyio + async def test_shutdown_canceled_prepare_error_unwinds_to_unloaded(self): + """shutdown() should suppress prepare errors from a canceled cold load.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + prepare_entered = threading.Event() + allow_prepare_finish = threading.Event() + started = 0 + + class FakeEngine: + def prepare_for_start(self): + prepare_entered.set() + allow_prepare_finish.wait() + raise RuntimeError("prepare boom") + + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + return None + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + acquire_task = asyncio.create_task(manager.acquire("default")) + assert await asyncio.wait_for( + asyncio.to_thread(prepare_entered.wait, 1.0), + timeout=1.0, + ) + + shutdown_task = asyncio.create_task(manager.shutdown()) + await asyncio.sleep(0) + allow_prepare_finish.set() + await shutdown_task + + with pytest.raises(asyncio.CancelledError): + await acquire_task + + status = manager.get_status("default") + assert started == 0 + assert manager.get_engine("default") is None + assert status["state"] == "unloaded" + assert status["last_error"] is None + + @pytest.mark.anyio + async def test_shutdown_raises_if_loaded_resident_cannot_unload(self): + """shutdown() should not report success when resident unload fails.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + raise RuntimeError("stop boom") + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + await manager.acquire("default") + await manager.release("default") + + with pytest.raises(RuntimeError): + await manager.shutdown() + + @pytest.mark.anyio + async def test_shutdown_attempts_later_residents_after_one_unload_failure(self): + """shutdown() should keep unloading later residents after one unload fails.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + stopped = [] + + class FakeEngine: + def __init__(self, model_key): + self.model_key = model_key + + async def start(self): + return None + + async def stop(self): + stopped.append(self.model_key) + if self.model_key == "a": + raise RuntimeError("stop boom") + + async def engine_factory(spec): + return FakeEngine(spec.model_key) + + manager = ResidencyManager( + 
engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="a", model_name="model-a")) + manager.register_model(ModelSpec(model_key="b", model_name="model-b")) + + await manager.acquire("a") + await manager.release("a") + await manager.acquire("b") + await manager.release("b") + + with pytest.raises(RuntimeError, match="a"): + await manager.shutdown() + + assert stopped == ["a", "b"] + assert manager.get_engine("b") is None + assert manager.get_status("b")["state"] == "unloaded" + + @pytest.mark.anyio + async def test_acquire_retries_if_idle_unload_wins_the_boundary(self, monkeypatch): + """Acquire should not hand back an engine that unloaded in the claim gap.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + stopped = 0 + started = 0 + now = {"value": 1000.0} + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + async def start(self): + nonlocal started + started += 1 + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + original = await manager.acquire("default") + await manager.release("default") + now["value"] += 120.0 + + entered_gap = asyncio.Event() + continue_gap = asyncio.Event() + original_ensure_loaded = manager.ensure_loaded + + async def delayed_ensure_loaded(model_key): + engine = await original_ensure_loaded(model_key) + if not entered_gap.is_set(): + entered_gap.set() + await continue_gap.wait() + return engine + + monkeypatch.setattr(manager, "ensure_loaded", delayed_ensure_loaded) + + acquire_task = asyncio.create_task(manager.acquire("default")) + await entered_gap.wait() + + unloaded = await manager.unload_if_idle("default") + continue_gap.set() + replacement = await acquire_task + + assert unloaded is True + assert stopped == 1 + assert created == 2 + assert started == 2 + assert replacement is not original + assert manager.get_status("default")["state"] == "loaded" + assert manager.get_status("default")["active_requests"] == 1 + + await manager.release("default") + + @pytest.mark.anyio + async def test_failed_unload_keeps_live_engine_tracked(self): + """Unload failure should not orphan a still-live engine.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + now = {"value": 1000.0} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + raise RuntimeError("boom") + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + original = await manager.acquire("default") + await manager.release("default") + now["value"] += 120.0 + + unloaded = await manager.unload_if_idle("default") + replacement = await manager.acquire("default") + + assert unloaded is False + assert manager.get_engine("default") is original + assert replacement is original + assert manager.get_status("default")["state"] == "loaded" + assert manager.get_status("default")["last_error"] == "boom" + assert created == 1 + + await 
manager.release("default") + + @pytest.mark.anyio + async def test_register_model_rejects_replacing_live_resident(self): + """register_model() should not orphan an already loaded resident.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + created = 0 + stopped = 0 + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="first-model")) + + await manager.acquire("default") + await manager.release("default") + + with pytest.raises(RuntimeError, match="Cannot replace resident model"): + manager.register_model( + ModelSpec(model_key="default", model_name="replacement-model") + ) + + await manager.shutdown() + + assert created == 1 + assert stopped == 1 + + @pytest.mark.anyio + async def test_failed_cold_load_cleans_up_partial_engine(self): + """A start() failure should still stop the partially built engine.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + stopped = 0 + + class FakeEngine: + async def start(self): + raise RuntimeError("boom") + + async def stop(self): + nonlocal stopped + stopped += 1 + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + with pytest.raises(RuntimeError, match="boom"): + await manager.acquire("default") + + assert manager.get_status("default")["state"] == "failed" + assert stopped == 1 + + @pytest.mark.anyio + async def test_cancelled_waiter_does_not_cancel_shared_unload(self): + """A canceled acquire waiter should not cancel the shared unload task.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + now = {"value": 1000.0} + stop_gate = asyncio.Event() + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + await stop_gate.wait() + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + await manager.acquire("default") + await manager.release("default") + now["value"] += 120.0 + + unload_task = asyncio.create_task(manager.unload_if_idle("default")) + await asyncio.sleep(0) + + acquire_task = asyncio.create_task(manager.acquire("default")) + await asyncio.sleep(0) + acquire_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await acquire_task + + stop_gate.set() + unloaded = await unload_task + + assert unloaded is True + assert manager.get_status("default")["state"] == "unloaded" + + @pytest.mark.anyio + async def test_cancelled_unload_waiter_does_not_cancel_shared_unload_task(self): + """Cancelling one unload waiter should not cancel the shared unload operation.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + now = {"value": 1000.0} + stop_gate = asyncio.Event() + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + await stop_gate.wait() + + async def engine_factory(spec): + return FakeEngine() + + manager = ResidencyManager( + 
engine_factory=engine_factory, + time_fn=lambda: now["value"], + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="default", model_name="test-model")) + + await manager.acquire("default") + await manager.release("default") + now["value"] += 120.0 + + first_waiter = asyncio.create_task(manager.unload_if_idle("default")) + await asyncio.sleep(0) + + second_waiter = asyncio.create_task(manager.unload_if_idle("default")) + await asyncio.sleep(0) + + first_waiter.cancel() + with pytest.raises(asyncio.CancelledError): + await first_waiter + + stop_gate.set() + unloaded = await second_waiter + + assert unloaded is True + assert manager.get_status("default")["state"] == "unloaded" + + +class TestResidencyManagerEdgeCases: + """Edge-case coverage for manager state transitions and error recovery.""" + + @pytest.mark.anyio + async def test_unload_if_idle_when_already_unloaded(self): + """unload_if_idle on an unloaded resident should return False, not corrupt state.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + class FakeEngine: + async def start(self): + pass + + async def stop(self): + pass + + async def factory(spec): + return FakeEngine() + + manager = ResidencyManager( + factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + # Never loaded — should be a no-op + result = await manager.unload_if_idle("m") + assert result is False + assert manager.get_status("m")["state"] == "unloaded" + + @pytest.mark.anyio + async def test_unload_if_idle_during_active_load(self): + """unload_if_idle should return False when a load is in progress.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + load_gate = asyncio.Event() + + class FakeEngine: + async def start(self): + await load_gate.wait() + + async def stop(self): + pass + + async def factory(spec): + return FakeEngine() + + manager = ResidencyManager( + factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + load_task = asyncio.create_task(manager.ensure_loaded("m")) + await asyncio.sleep(0) # let load start + + result = await manager.unload_if_idle("m") + assert result is False + + load_gate.set() + await load_task + await manager.shutdown() + + @pytest.mark.anyio + async def test_register_model_after_failure_replaces_entry(self): + """A model in FAILED state should be replaceable via register_model.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + call_count = 0 + + class FailEngine: + async def start(self): + raise RuntimeError("boom") + + async def stop(self): + pass + + class GoodEngine: + async def start(self): + pass + + async def stop(self): + pass + + async def factory(spec): + nonlocal call_count + call_count += 1 + if call_count == 1: + return FailEngine() + return GoodEngine() + + manager = ResidencyManager(factory, auto_unload_idle_seconds=0) + spec = ModelSpec(model_key="m", model_name="test") + manager.register_model(spec) + + with pytest.raises(RuntimeError, match="boom"): + await manager.ensure_loaded("m") + assert manager.get_status("m")["state"] == "failed" + + # Re-register should succeed because the model is dormant/failed + manager.register_model(spec) + assert manager.get_status("m")["state"] == "unloaded" + + engine = await manager.ensure_loaded("m") + assert engine is not None + assert manager.get_status("m")["state"] == "loaded" + await manager.shutdown() + + 
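# (Resident state machine assumed by the edge-case tests in this class,
+ # pieced together from their status assertions: unloaded -> loading ->
+ # loaded, with failed as the error terminal; active_requests is a
+ # refcount that acquire() increments and release() decrements.)
+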
@pytest.mark.anyio + async def test_engine_factory_raises_before_engine_created(self): + """If the factory itself raises (not engine.start), state should be FAILED.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + async def failing_factory(spec): + raise RuntimeError("model not found") + + manager = ResidencyManager(failing_factory, auto_unload_idle_seconds=0) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + with pytest.raises(RuntimeError, match="model not found"): + await manager.ensure_loaded("m") + + status = manager.get_status("m") + assert status["state"] == "failed" + assert "model not found" in status["last_error"] + + @pytest.mark.anyio + async def test_rapid_acquire_release_refcount_stays_consistent(self): + """Rapid acquire/release cycles should keep active_requests consistent.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + class FakeEngine: + async def start(self): + pass + + async def stop(self): + pass + + async def factory(spec): + return FakeEngine() + + manager = ResidencyManager( + factory, + time_fn=lambda: 1000.0, + auto_unload_idle_seconds=60, + ) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + # Rapid sequential acquire/release + for _ in range(50): + await manager.acquire("m") + await manager.release("m") + + assert manager.get_status("m")["active_requests"] == 0 + + # Concurrent acquire then sequential release + engines = await asyncio.gather(*[manager.acquire("m") for _ in range(20)]) + assert manager.get_status("m")["active_requests"] == 20 + for _ in engines: + await manager.release("m") + assert manager.get_status("m")["active_requests"] == 0 + + @pytest.mark.anyio + async def test_shutdown_shields_unload_from_cancellation(self): + """Cancelling shutdown() should not orphan a half-stopped engine.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + stop_started = asyncio.Event() + stop_gate = asyncio.Event() + stopped = 0 + + class FakeEngine: + async def start(self): + pass + + async def stop(self): + nonlocal stopped + stop_started.set() + await stop_gate.wait() + stopped += 1 + + async def factory(spec): + return FakeEngine() + + manager = ResidencyManager(factory, auto_unload_idle_seconds=0) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + await manager.ensure_loaded("m") + + shutdown_task = asyncio.create_task(manager.shutdown()) + await stop_started.wait() + + # Cancel shutdown while engine.stop() is in flight + shutdown_task.cancel() + await asyncio.sleep(0) + + # Let engine.stop() complete + stop_gate.set() + with pytest.raises(asyncio.CancelledError): + await shutdown_task + + # The engine should still have been fully stopped + assert stopped == 1 + assert manager.get_status("m")["state"] == "unloaded" + + +class TestSuspendCancellationDedup: + """Verify lifecycle.py uses the shared suspend_cancellation from base.""" + + @pytest.mark.anyio + async def test_residency_manager_uses_shared_suspend_cancellation(self): + """ResidencyManager cleanup paths should work with the shared helper.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + stopped = 0 + + class FakeEngine: + async def start(self): + pass + + async def stop(self): + nonlocal stopped + stopped += 1 + + def prepare_for_start(self): + pass + + async def factory(spec): + return FakeEngine() + + manager = ResidencyManager(factory, auto_unload_idle_seconds=0) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + engine = await 
manager.ensure_loaded("m") + assert engine is not None + await manager.shutdown() + assert stopped == 1 diff --git a/tests/test_lifecycle_server.py b/tests/test_lifecycle_server.py new file mode 100644 index 000000000..d5e5e4690 --- /dev/null +++ b/tests/test_lifecycle_server.py @@ -0,0 +1,3628 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Server integration tests for lifecycle / residency behavior.""" + +from __future__ import annotations + +import asyncio +import time +from contextlib import suppress +from types import SimpleNamespace + +import pytest + + +async def _wait_for_resident_state(manager, model_key: str, state: str) -> None: + """Poll until a resident reaches the expected state.""" + while manager.get_status(model_key)["state"] != state: + await asyncio.sleep(0) + + +@pytest.fixture(autouse=True) +def restore_server_globals(): + """Restore mutated server module globals between lifecycle tests.""" + import vllm_mlx.server as srv + + sentinel = object() + global_names = ( + "_engine", + "_model_name", + "_model_path", + "_default_model_key", + "_default_max_tokens", + "_default_timeout", + "_default_temperature", + "_default_top_p", + "_force_mllm_model", + "_auto_unload_idle_seconds", + "_lazy_load_model", + "_residency_manager", + "_lifecycle_task", + "_lifespan_active", + "_mcp_manager", + "_mcp_executor", + "_embedding_engine", + "_embedding_model_locked", + "_api_key", + "_auth_warning_logged", + "_rate_limiter", + "_reasoning_parser", + "_enable_auto_tool_choice", + "_tool_call_parser", + "_tool_parser_instance", + "_idle_unload_enabled", + ) + snapshot = {name: getattr(srv, name, sentinel) for name in global_names} + + yield + + leaked_task = getattr(srv, "_lifecycle_task", None) + original_task = snapshot["_lifecycle_task"] + if ( + leaked_task is not sentinel + and leaked_task is not None + and leaked_task is not original_task + and not leaked_task.done() + ): + leaked_task.cancel() + + for name, value in snapshot.items(): + if value is sentinel: + if hasattr(srv, name): + delattr(srv, name) + else: + setattr(srv, name, value) + + # _idle_unload_enabled is a lazily-created asyncio.Event bound to whatever + # event loop was running when _get_idle_unload_event() was first called. + # Reset to None so the next test gets a fresh Event on its own loop. 
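+ # (A sketch of the lazy getter this guards against, assuming the common
+ # pattern; the real _get_idle_unload_event() may differ:
+ #     if srv._idle_unload_enabled is None:
+ #         srv._idle_unload_enabled = asyncio.Event()
+ #     return srv._idle_unload_enabled
+ # A stale Event would therefore stay bound to a previous test's loop.)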
+ srv._idle_unload_enabled = None + + +class TestLifecycleStatusEndpoints: + """Lock in residency metadata surfaced by server status endpoints.""" + + @pytest.mark.anyio + async def test_status_reports_unloaded_resident_metadata(self, monkeypatch): + """Status should surface residency details even when model is unloaded.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "state": "unloaded", + "active_requests": 0, + "last_used_at": 1710200000.0, + "loaded_at": None, + "auto_unload_idle_seconds": 300, + } + ) + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr( + srv, "_model_name", "mlx-community/Qwen3-0.6B-8bit", raising=False + ) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + + payload = await srv.status() + + assert payload["status"] == "not_loaded" + assert payload["model"] == "mlx-community/Qwen3-0.6B-8bit" + assert payload["residency"]["model_key"] == "default" + assert payload["residency"]["state"] == "unloaded" + assert payload["residency"]["active_requests"] == 0 + assert payload["residency"]["last_used_at"] == 1710200000.0 + assert payload["residency"]["loaded_at"] is None + assert payload["residency"]["auto_unload_idle_seconds"] == 300 + assert payload["requests"] == [] + + @pytest.mark.anyio + async def test_health_exposes_residency_state_for_unloaded_model(self, monkeypatch): + """Health should report lifecycle state, not only a loaded bool.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "state": "unloaded", + "active_requests": 0, + "last_used_at": 1710200000.0, + "loaded_at": None, + "auto_unload_idle_seconds": 120, + } + ) + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr( + srv, + "_model_name", + "mlx-community/Llama-3.2-3B-Instruct-4bit", + raising=False, + ) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + payload = await srv.health() + + assert payload["status"] == "healthy" + assert payload["model_loaded"] is False + assert payload["model_name"] == "mlx-community/Llama-3.2-3B-Instruct-4bit" + assert payload["residency_state"] == "unloaded" + assert payload["active_requests"] == 0 + assert payload["last_used_at"] == 1710200000.0 + assert payload["loaded_at"] is None + assert payload["auto_unload_idle_seconds"] == 120 + + @pytest.mark.anyio + async def test_failed_resident_surfaces_as_unhealthy_and_failed(self, monkeypatch): + """Public status should not leak backend model identity or raw errors.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "model_name": "/tmp/private-local-model", + "state": "failed", + "active_requests": 0, + "last_used_at": 1710200000.0, + "loaded_at": None, + "last_error": "reload boom", + "auto_unload_idle_seconds": 120, + } + ) + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr( + srv, + "_model_name", + "friendly-model", + raising=False, + ) + monkeypatch.setattr( + srv, + "_model_path", + "/tmp/private-local-model", + raising=False, + ) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) 
+ monkeypatch.setattr(srv, "_mcp_manager", None) + + health_payload = await srv.health() + status_payload = await srv.status() + + assert health_payload["status"] == "unhealthy" + assert health_payload["model_loaded"] is False + assert health_payload["residency_state"] == "failed" + # /health surfaces a sanitized error category for failed residents + assert health_payload["last_error"] == "model_load_failed" + assert status_payload["status"] == "not_loaded" + assert status_payload["model"] == "friendly-model" + assert status_payload["residency"]["state"] == "failed" + assert status_payload["residency"]["model_name"] == "friendly-model" + # /v1/status surfaces a generic error indicator, not raw exception text + assert status_payload["residency"]["last_error"] == "model_load_failed" + assert status_payload["requests"] == [] + + @pytest.mark.anyio + async def test_health_preserves_mllm_type_when_resident_is_unloaded( + self, monkeypatch + ): + """Unloaded multimodal residents should still report model_type=mllm.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "state": "unloaded", + "active_requests": 0, + "last_used_at": 1710200000.0, + "loaded_at": None, + "auto_unload_idle_seconds": 120, + } + ) + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr( + srv, + "_model_name", + "mlx-community/gemma-3-4b-it-4bit", + raising=False, + ) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + payload = await srv.health() + + assert payload["model_type"] == "mllm" + + @pytest.mark.anyio + async def test_health_uses_model_path_for_unloaded_served_alias_mllm( + self, monkeypatch + ): + """Served aliases should not hide unloaded multimodal model type.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "model_name": "mlx-community/gemma-3-4b-it-4bit", + "state": "unloaded", + "active_requests": 0, + "last_used_at": 1710200000.0, + "loaded_at": None, + "auto_unload_idle_seconds": 120, + } + ) + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr(srv, "_model_name", "prod-chat", raising=False) + monkeypatch.setattr( + srv, + "_model_path", + "mlx-community/gemma-3-4b-it-4bit", + raising=False, + ) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + payload = await srv.health() + + assert payload["model_type"] == "mllm" + + @pytest.mark.anyio + async def test_health_uses_force_mllm_for_unloaded_local_model(self, monkeypatch): + """force_mllm should survive the unloaded-resident health fallback.""" + import vllm_mlx.server as srv + + monkeypatch.setattr(srv, "_mcp_manager", None) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + srv.load_model( + "/tmp/local-model", + force_mllm=True, + auto_unload_idle_seconds=60, + ) + monkeypatch.setattr(srv, "_engine", None) + + payload = await srv.health() + + assert payload["model_type"] == "mllm" + + +class TestCompletionStreamingRelease: + """Verify the completion endpoint releases residency on all paths.""" + + @pytest.mark.anyio + async def 
test_completion_nonstreaming_error_releases_active_request( + self, monkeypatch + ): + """Non-streaming completion errors must still release the active request.""" + import vllm_mlx.server as srv + + releases = {"count": 0} + acquires = {"count": 0} + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + + async def start(self): + pass + + async def stop(self): + pass + + async def generate(self, **kwargs): + raise RuntimeError("generation failed") + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + acquires["count"] += 1 + return FakeEngine() + + async def fake_release(*, count_activity=True): + releases["count"] += 1 + + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", "test-model") + + class FakeRequest: + async def is_disconnected(self): + return False + + request = SimpleNamespace( + model="test-model", + prompt="hello", + stream=False, + max_tokens=10, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + specprefill=None, + specprefill_keep_pct=None, + stop=None, + timeout=60.0, + ) + + with pytest.raises(RuntimeError, match="generation failed"): + await srv.create_completion(request, FakeRequest()) + + assert acquires["count"] == 1 + assert ( + releases["count"] == 1 + ), "Non-streaming completion must release residency on generation errors" + + @pytest.mark.anyio + async def test_completion_streaming_release_matches_chat_pattern(self, monkeypatch): + """Streaming completion should use try/finally like chat completion does.""" + import vllm_mlx.server as srv + + # The chat completion endpoint uses a release_on_exit flag with try/finally. + # The completion endpoint should follow the same pattern for consistency + # and safety. This test verifies that the streaming path eventually + # calls release via the cleanup callback. 
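+ # Roughly the endpoint-side shape this expects (a sketch under that
+ # assumption, not the server's exact code; build_streaming_response is a
+ # placeholder name):
+ #     engine = await _acquire_default_engine_for_request(raw_request, ...)
+ #     release_on_exit = True
+ #     try:
+ #         response = build_streaming_response(engine, ...)
+ #         # the response's cleanup callback owns the release from here on
+ #         release_on_exit = False
+ #         return response
+ #     finally:
+ #         if release_on_exit:
+ #             await _release_default_engine()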
+ releases = {"count": 0} + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + + async def start(self): + pass + + async def stop(self): + pass + + async def stream_generate(self, **kwargs): + yield SimpleNamespace( + text="done", + new_text="done", + finish_reason="stop", + completion_tokens=1, + prompt_tokens=1, + finished=True, + ) + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return FakeEngine() + + async def fake_release(*, count_activity=True): + releases["count"] += 1 + + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", "test-model") + + class FakeRequest: + async def is_disconnected(self): + return False + + request = SimpleNamespace( + model="test-model", + prompt="hello", + stream=True, + max_tokens=10, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + specprefill=None, + specprefill_keep_pct=None, + stop=None, + timeout=None, + ) + + response = await srv.create_completion(request, FakeRequest()) + + # Iterate to completion + async for _ in response.body_iterator: + pass + + assert ( + releases["count"] == 1 + ), "Streaming completion must release residency via cleanup callback" + + +class TestStatusEndpointEngineRace: + """Verify status/health endpoints handle engine being None.""" + + @pytest.mark.anyio + async def test_status_endpoint_returns_not_loaded_when_engine_is_none( + self, monkeypatch + ): + """/v1/status should not 500 if engine is unloaded between check and use.""" + import vllm_mlx.server as srv + + call_count = {"n": 0} + + class DisappearingEngine: + """Engine that disappears after the null check.""" + + def get_stats(self): + call_count["n"] += 1 + return { + "running": True, + "uptime_seconds": 10, + "steps_executed": 0, + "num_running": 0, + "num_waiting": 0, + "num_requests_processed": 0, + "total_prompt_tokens": 0, + "total_completion_tokens": 0, + "metal_active_memory_gb": 0, + "metal_peak_memory_gb": 0, + "metal_cache_memory_gb": 0, + "requests": [], + } + + # Set engine to a real object, then unload it mid-call by patching + engine = DisappearingEngine() + monkeypatch.setattr(srv, "_engine", engine) + monkeypatch.setattr(srv, "_model_name", "test") + monkeypatch.setattr(srv, "_residency_manager", None) + monkeypatch.setattr(srv, "_default_model_key", None) + + # Normal case: should work + result = await srv.status() + assert result["status"] == "running" + + # Now simulate the race: engine becomes None after the check + monkeypatch.setattr(srv, "_engine", None) + result = await srv.status() + assert result["status"] == "not_loaded" + + @pytest.mark.anyio + async def test_health_endpoint_handles_engine_none(self, monkeypatch): + """/health should not 500 when engine is None.""" + import vllm_mlx.server as srv + + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr(srv, "_model_name", "test") + monkeypatch.setattr(srv, "_model_path", None) + monkeypatch.setattr(srv, "_force_mllm_model", False) + monkeypatch.setattr(srv, "_mcp_manager", None) + monkeypatch.setattr(srv, "_residency_manager", None) + monkeypatch.setattr(srv, "_default_model_key", None) + + result = await srv.health() + assert result["status"] == "healthy" + assert result["model_loaded"] is False + + +class TestToolParserUsesLocalEngine: + """Tool parser should initialize from the request-local 
engine.""" + + @pytest.mark.anyio + async def test_chat_completion_initializes_parser_from_acquired_engine( + self, monkeypatch + ): + """Chat completion should seed parser state from the acquired engine.""" + from vllm_mlx.engine.base import GenerationOutput + import vllm_mlx.server as srv + + parser_tokenizers = [] + + class FakeParser: + def __init__(self, tokenizer=None): + parser_tokenizers.append(tokenizer) + + def reset(self): + return None + + def extract_tool_calls(self, output_text, request_dict=None): + return SimpleNamespace( + tools_called=False, + tool_calls=[], + content=output_text, + ) + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + async def chat(self, **kwargs): + return GenerationOutput( + text="hello", + completion_tokens=1, + prompt_tokens=1, + ) + + local_engine = FakeEngine("tok-local") + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return local_engine + + async def fake_release(*, count_activity=True): + return None + + monkeypatch.setattr(srv, "_validate_model_name", lambda _m: None) + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", "served-model") + monkeypatch.setattr(srv, "_default_max_tokens", 32) + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr(srv, "_reasoning_parser", None) + monkeypatch.setattr(srv, "_enable_auto_tool_choice", True) + monkeypatch.setattr(srv, "_tool_call_parser", "fake") + monkeypatch.setattr(srv, "_tool_parser_instance", None) + monkeypatch.setattr( + srv.ToolParserManager, + "get_tool_parser", + lambda name: FakeParser, + ) + + class FakeRawRequest: + async def is_disconnected(self): + return False + + request = srv.ChatCompletionRequest( + model="user-sent-model-name", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tool_choice="auto", + tools=[ + { + "type": "function", + "function": { + "name": "lookup_weather", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + + await srv.create_chat_completion(request, FakeRawRequest()) + + assert parser_tokenizers == ["tok-local"], ( + "Parser init should use the request-local engine acquired for " + "this request, not the stale global _engine" + ) + + +class TestLifecycleFailureHandling: + """Regression coverage for lifecycle failure paths.""" + + @pytest.mark.anyio + async def test_anthropic_validation_error_does_not_acquire_resident( + self, monkeypatch + ): + """Malformed Anthropic payloads should not touch residency at all.""" + from pydantic import ValidationError + + import vllm_mlx.server as srv + + calls = {"acquires": 0, "releases": 0} + + class FakeRequest: + async def json(self): + return {} + + class FakeEngine: + preserve_native_tool_format = False + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + calls["acquires"] += 1 + return FakeEngine() + + async def fake_release(*, count_activity=True): + calls["releases"] += 1 + + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + + with pytest.raises(ValidationError): + await srv.create_anthropic_message(FakeRequest()) + + assert calls["acquires"] == 0 + assert calls["releases"] == 0 + + @pytest.mark.anyio + async def 
test_chat_completion_prep_error_releases_resident(self, monkeypatch): + """Prep failures after acquire should still release chat residency.""" + import vllm_mlx.server as srv + + calls = {"acquires": 0, "releases": 0} + + class FakeEngine: + is_mllm = False + preserve_native_tool_format = False + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + calls["acquires"] += 1 + return FakeEngine() + + async def fake_release(*, count_activity=True): + calls["releases"] += 1 + + def fake_extract(messages, preserve_native_format): + return ([{"role": "user", "content": "hi"}], [], []) + + def fake_convert_tools(_tools): + raise RuntimeError("boom") + + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "extract_multimodal_content", fake_extract) + monkeypatch.setattr(srv, "convert_tools_for_template", fake_convert_tools) + + request = SimpleNamespace( + stream=False, + messages=[SimpleNamespace(role="user", content="hi")], + model="mlx-community/Qwen3-0.6B-8bit", + max_tokens=None, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + response_format=None, + tools=[{"type": "function"}], + tool_choice=None, + enable_thinking=None, + video_fps=None, + video_max_frames=None, + specprefill=None, + specprefill_keep_pct=None, + chat_template_kwargs=None, + stop=None, + timeout=None, + ) + + with pytest.raises(RuntimeError, match="boom"): + await srv.create_chat_completion(request, SimpleNamespace()) + + assert calls["acquires"] == 1 + assert calls["releases"] == 1 + + @pytest.mark.anyio + async def test_request_acquire_helper_disconnect_covers_final_lease( + self, monkeypatch + ): + """Disconnects should abort even if residency is stalled in final acquire().""" + import vllm_mlx.server as srv + + acquire_cancelled = asyncio.Event() + lease_gate = asyncio.Event() + + class FakeEngine: + preserve_native_tool_format = False + + class FakeRequest: + async def is_disconnected(self): + return True + + async def fake_acquire(model_key): + try: + await lease_gate.wait() + except asyncio.CancelledError: + acquire_cancelled.set() + raise + return FakeEngine() + + fake_manager = SimpleNamespace(acquire=fake_acquire) + + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + + total_timeout, deadline = srv._start_request_budget(60.0) + result = await srv._acquire_default_engine_for_request( + FakeRequest(), + total_timeout=total_timeout, + deadline=deadline, + ) + + assert result is None + await asyncio.wait_for(acquire_cancelled.wait(), timeout=1.0) + + @pytest.mark.anyio + async def test_wait_with_disconnect_reports_total_request_timeout(self): + """Timeout details should reflect the configured request budget, not the sub-step.""" + from fastapi import HTTPException + + import vllm_mlx.server as srv + + class FakeRawRequest: + async def json(self): + return { + "model": "mlx-community/Qwen3-0.6B-8bit", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "max_tokens": 16, + } + + async def is_disconnected(self): + return False + + with pytest.raises(HTTPException, match="10.0 seconds"): + await srv._wait_with_disconnect( + asyncio.sleep(0.05), + FakeRawRequest(), + timeout=0.01, + 
timeout_detail_seconds=10.0, + poll_interval=0.001, + ) + + @pytest.mark.anyio + async def test_lifespan_startup_failure_cleans_up_loaded_resident_and_loop( + self, monkeypatch + ): + """Startup failures before yield should not leak lifecycle tasks or loaded residents.""" + import vllm_mlx.server as srv + + stopped = {"count": 0} + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + stopped["count"] += 1 + + class FakeRawRequest: + async def json(self): + return { + "model": "mlx-community/Qwen3-0.6B-8bit", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "max_tokens": 16, + } + + async def is_disconnected(self): + return False + + async def fake_engine_factory(spec): + return FakeEngine() + + async def fake_init_mcp(config_path): + raise RuntimeError("mcp boom") + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "init_mcp", fake_init_mcp) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setenv("VLLM_MLX_MCP_CONFIG", "/tmp/fake-mcp.json") + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=False, + ) + + lifespan = srv.lifespan(srv.app) + try: + with pytest.raises(RuntimeError, match="mcp boom"): + await lifespan.__anext__() + + status = srv._get_lifecycle_status() + assert srv._lifecycle_task is None + assert srv._engine is None + assert status is not None + assert status["state"] == "unloaded" + assert stopped["count"] == 1 + finally: + if srv._lifecycle_task is not None: + srv._lifecycle_task.cancel() + with suppress(asyncio.CancelledError): + await srv._lifecycle_task + srv._lifecycle_task = None + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + async def test_lifespan_startup_failure_preserves_original_exception( + self, monkeypatch, caplog + ): + """Cleanup failures should not replace the original startup exception.""" + import vllm_mlx.server as srv + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + raise RuntimeError("stop boom") + + async def fake_engine_factory(spec): + return FakeEngine() + + async def fake_init_mcp(config_path): + raise RuntimeError("mcp boom") + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "init_mcp", fake_init_mcp) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setenv("VLLM_MLX_MCP_CONFIG", "/tmp/fake-mcp.json") + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=False, + ) + + lifespan = 
srv.lifespan(srv.app) + try: + caplog.clear() + with pytest.raises(RuntimeError, match="mcp boom") as excinfo: + await lifespan.__anext__() + assert excinfo.value.__cause__ is None + assert ( + "Lifecycle cleanup failed while preserving the original exception" + in caplog.text + ) + assert "stop boom" in caplog.text + finally: + if srv._lifecycle_task is not None: + srv._lifecycle_task.cancel() + with suppress(asyncio.CancelledError): + await srv._lifecycle_task + srv._lifecycle_task = None + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + async def test_lifespan_startup_failure_keeps_live_runtime_guarded_if_cleanup_fails( + self, monkeypatch + ): + """Startup failures should not orphan a live runtime when cleanup also fails.""" + import vllm_mlx.server as srv + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + raise RuntimeError("stop boom") + + async def fake_engine_factory(spec): + return FakeEngine() + + async def fake_init_mcp(config_path): + raise RuntimeError("mcp boom") + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "init_mcp", fake_init_mcp) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setenv("VLLM_MLX_MCP_CONFIG", "/tmp/fake-mcp.json") + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=False, + ) + + lifespan = srv.lifespan(srv.app) + try: + with pytest.raises(RuntimeError, match="mcp boom"): + await lifespan.__anext__() + + status = srv._get_lifecycle_status() + assert srv._engine is not None + assert status is not None + assert status["state"] == "loaded" + with pytest.raises(RuntimeError, match="existing residency manager"): + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + ) + finally: + if srv._lifecycle_task is not None: + srv._lifecycle_task.cancel() + with suppress(asyncio.CancelledError): + await srv._lifecycle_task + srv._lifecycle_task = None + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + async def test_eager_residency_registers_unloaded_resident_before_lifespan_startup( + self, monkeypatch + ): + """Pre-lifespan lifecycle setup should stay unloaded until startup runs.""" + from fastapi import HTTPException + import vllm_mlx.server as srv + + create_calls = {"count": 0} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + async def fake_engine_factory(spec): + create_calls["count"] += 1 + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + 
monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + try: + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=False, + ) + + lifecycle = srv._get_lifecycle_status() + health_payload = await srv.health() + status_payload = await srv.status() + + assert create_calls["count"] == 0 + assert srv._engine is None + assert srv._residency_manager is not None + assert lifecycle is not None + assert lifecycle["state"] == "unloaded" + assert lifecycle["active_requests"] == 0 + assert health_payload["model_loaded"] is False + assert health_payload["residency_state"] == "unloaded" + assert status_payload["status"] == "not_loaded" + assert status_payload["residency"]["state"] == "unloaded" + + with pytest.raises(HTTPException, match="Model not loaded"): + srv.get_engine() + finally: + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + + @pytest.mark.anyio + async def test_load_model_preserves_scheduler_config_when_enabling_residency( + self, monkeypatch + ): + """Residency load should preserve batching config through engine factory.""" + import vllm_mlx.server as srv + + scheduler_config = SimpleNamespace( + enable_mtp=True, + mtp_num_draft_tokens=4, + mtp_optimistic=True, + ) + captured = {} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + async def fake_engine_factory(spec): + captured["spec"] = spec + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + try: + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + use_batching=True, + scheduler_config=scheduler_config, + stream_interval=7, + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + manager = srv._residency_manager + assert manager is not None + + resident = manager._residents["default"] + assert resident.spec.use_batching is True + assert resident.spec.scheduler_config is scheduler_config + assert resident.spec.stream_interval == 7 + assert manager.auto_unload_idle_seconds == 60.0 + + engine = await srv._acquire_default_engine() + + assert isinstance(engine, FakeEngine) + assert captured["spec"].use_batching is True + assert captured["spec"].scheduler_config is scheduler_config + assert captured["spec"].stream_interval == 7 + assert srv._engine is engine + + await srv._release_default_engine() + finally: + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + + @pytest.mark.anyio + async def test_eager_residency_stays_loaded_until_startup_ready(self, monkeypatch): + """Eager residency should not auto-unload before startup reaches readiness.""" + import vllm_mlx.server as srv + + now = {"value": 1000.0} + stopped = {"count": 0} + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + 
stopped["count"] += 1 + + async def fake_engine_factory(spec): + return FakeEngine() + + async def fake_init_mcp(config_path): + now["value"] += 120.0 + await asyncio.sleep(0.05) + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "init_mcp", fake_init_mcp) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setenv("VLLM_MLX_MCP_CONFIG", "/tmp/fake-mcp.json") + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.02, + lazy_load_model=False, + ) + assert srv._residency_manager is not None + srv._residency_manager._time_fn = lambda: now["value"] + + lifespan = srv.lifespan(srv.app) + try: + await lifespan.__anext__() + + status = srv._get_lifecycle_status() + assert srv._engine is not None + assert status is not None + assert status["state"] == "loaded" + assert stopped["count"] == 0 + finally: + if srv._lifecycle_task is not None: + srv._lifecycle_task.cancel() + with suppress(asyncio.CancelledError): + await srv._lifecycle_task + srv._lifecycle_task = None + if srv._residency_manager is not None: + with suppress(Exception): + await srv._residency_manager.shutdown() + srv._sync_engine_from_residency() + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + @pytest.mark.parametrize("engine_kind", ["simple", "batched-llm", "batched-mllm"]) + async def test_eager_engine_start_cancellation_cleans_prepared_state( + self, monkeypatch, engine_kind + ): + """Cancelling eager engine startup should not leave prepared model state behind.""" + import threading + + if engine_kind == "simple": + import vllm_mlx.engine.simple as engine_mod + + engine = engine_mod.SimpleEngine("fake-model") + else: + import vllm_mlx.engine.batched as engine_mod + + engine = engine_mod.BatchedEngine("fake-model") + if engine_kind == "batched-llm": + + async def unexpected_start_llm(): + raise AssertionError("_start_llm should not run after cancellation") + + monkeypatch.setattr(engine, "_is_mllm", False, raising=False) + monkeypatch.setattr( + engine, "_start_llm", unexpected_start_llm, raising=False + ) + else: + + async def unexpected_start_mllm(): + raise AssertionError( + "_start_mllm should not run after cancellation" + ) + + monkeypatch.setattr(engine, "_is_mllm", True, raising=False) + monkeypatch.setattr( + engine, "_start_mllm", unexpected_start_mllm, raising=False + ) + + prepare_entered = threading.Event() + allow_prepare_finish = threading.Event() + prepare_finished = threading.Event() + + def fake_prepare(): + prepare_entered.set() + allow_prepare_finish.wait() + engine._model = object() + if engine_kind == "batched-mllm": + engine._processor = object() + engine._mllm_instance = object() + elif hasattr(engine, "_tokenizer"): + engine._tokenizer = object() + prepare_finished.set() + + monkeypatch.setattr(engine, "prepare_for_start", fake_prepare) + + start_task = asyncio.create_task(engine.start()) + try: + assert await asyncio.wait_for( + asyncio.to_thread(prepare_entered.wait, 1.0), + timeout=1.0, + ) + + start_task.cancel() + allow_prepare_finish.set() + + with pytest.raises(asyncio.CancelledError): + await start_task + + assert await asyncio.wait_for( + 
asyncio.to_thread(prepare_finished.wait, 1.0), + timeout=1.0, + ) + assert engine._loaded is False + assert engine._model is None + if engine_kind == "batched-mllm": + assert engine._processor is None + assert engine._mllm_instance is None + elif hasattr(engine, "_tokenizer"): + assert engine._tokenizer is None + finally: + allow_prepare_finish.set() + if not start_task.done(): + start_task.cancel() + with suppress(asyncio.CancelledError): + await start_task + with suppress(Exception): + await engine.stop() + + @pytest.mark.anyio + async def test_simple_start_cancellation_preserves_cancelled_error_when_stop_fails( + self, monkeypatch, caplog + ): + """Simple eager startup should still surface cancellation if stop() fails.""" + import threading + + import vllm_mlx.engine.simple as engine_mod + + engine = engine_mod.SimpleEngine("fake-model") + prepare_entered = threading.Event() + allow_prepare_finish = threading.Event() + stop_calls = {"count": 0} + + def fake_prepare(): + prepare_entered.set() + allow_prepare_finish.wait() + engine._model = object() + + async def failing_stop(): + stop_calls["count"] += 1 + raise RuntimeError("stop boom") + + monkeypatch.setattr(engine, "prepare_for_start", fake_prepare) + monkeypatch.setattr(engine, "stop", failing_stop) + + start_task = asyncio.create_task(engine.start()) + try: + assert await asyncio.wait_for( + asyncio.to_thread(prepare_entered.wait, 1.0), + timeout=1.0, + ) + + caplog.set_level("ERROR") + caplog.clear() + start_task.cancel() + allow_prepare_finish.set() + + with pytest.raises(asyncio.CancelledError): + await start_task + + assert stop_calls["count"] == 1 + assert any( + record.levelname == "ERROR" + and "startup cleanup failed while preserving cancellation" + in record.getMessage() + for record in caplog.records + ) + finally: + allow_prepare_finish.set() + if not start_task.done(): + start_task.cancel() + with suppress(asyncio.CancelledError): + await start_task + + @pytest.mark.anyio + async def test_run_blocking_startup_work_waits_for_thread_under_repeated_cancel( + self, + ): + """Repeated cancellation should not return before blocking startup work finishes.""" + import threading + + from vllm_mlx.engine.base import run_blocking_startup_work + + entered = threading.Event() + allow_finish = threading.Event() + finished = threading.Event() + + def blocking_work(): + entered.set() + allow_finish.wait() + # Keep a deterministic post-release window so pre-fix behavior + # can still return cancellation before thread completion. 
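+ # (Without the drain behavior under test, the second cancel() below would
+ # surface CancelledError to the caller before this thread reaches
+ # finished.set(); the assertion after the await guards exactly that.)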
+ time.sleep(0.05) + finished.set() + + task = asyncio.create_task(run_blocking_startup_work(blocking_work)) + try: + assert await asyncio.wait_for( + asyncio.to_thread(entered.wait, 1.0), + timeout=1.0, + ) + + task.cancel() + await asyncio.sleep(0) + task.cancel() + + allow_finish.set() + with pytest.raises(asyncio.CancelledError): + await task + + assert finished.is_set() is True + finally: + allow_finish.set() + assert await asyncio.wait_for( + asyncio.to_thread(finished.wait, 1.0), + timeout=1.0, + ) + + @pytest.mark.anyio + async def test_blocking_cache_io_waits_for_thread_under_repeated_cancel( + self, + ): + """Repeated cancellation should not return before cache I/O thread finishes.""" + import threading + + import vllm_mlx.server as srv + + entered = threading.Event() + allow_finish = threading.Event() + finished = threading.Event() + + class FakeEngine: + pass + + def blocking_io(_engine): + entered.set() + allow_finish.wait() + # Keep a deterministic post-release window so pre-fix behavior + # can still return cancellation before thread completion. + time.sleep(0.05) + finished.set() + + task = asyncio.create_task( + srv._run_blocking_engine_cache_io(blocking_io, FakeEngine()) + ) + try: + assert await asyncio.wait_for( + asyncio.to_thread(entered.wait, 1.0), + timeout=1.0, + ) + + task.cancel() + await asyncio.sleep(0) + task.cancel() + + allow_finish.set() + with pytest.raises(asyncio.CancelledError): + await task + + assert finished.is_set() is True + finally: + allow_finish.set() + assert await asyncio.wait_for( + asyncio.to_thread(finished.wait, 1.0), + timeout=1.0, + ) + + @pytest.mark.anyio + async def test_prepare_engine_start_waits_for_thread_under_repeated_cancel(self): + """Repeated cancellation of a residency cold load must not return before + prepare_for_start() finishes, otherwise the thread can keep mutating + model state past request/shutdown boundaries.""" + import threading + + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + entered = threading.Event() + allow_finish = threading.Event() + finished = threading.Event() + + class SlowPrepEngine: + async def start(self): + pass + + async def stop(self): + pass + + def prepare_for_start(self): + entered.set() + allow_finish.wait() + time.sleep(0.05) + finished.set() + + async def factory(spec): + return SlowPrepEngine() + + manager = ResidencyManager(factory, auto_unload_idle_seconds=0) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + load_task = asyncio.create_task(manager.ensure_loaded("m")) + try: + # Wait for prepare_for_start to enter + assert await asyncio.wait_for( + asyncio.to_thread(entered.wait, 2.0), + timeout=2.0, + ) + + # Double-cancel the load task + load_task.cancel() + await asyncio.sleep(0) + load_task.cancel() + + # Let the blocking work complete + allow_finish.set() + + with pytest.raises(asyncio.CancelledError): + await load_task + + # The critical assertion: prepare_for_start must have fully + # completed before the CancelledError was raised. 
+ assert finished.is_set() is True, ( + "prepare_for_start() was still running after load task returned " + "CancelledError — repeated cancellation escaped the drain loop" + ) + finally: + allow_finish.set() + assert await asyncio.wait_for( + asyncio.to_thread(finished.wait, 2.0), + timeout=2.0, + ) + # Clean up any leftover resident state + with suppress(Exception): + await manager.shutdown() + + @pytest.mark.anyio + async def test_run_blocking_startup_work_does_not_livelock_on_cancelled_inner_task( + self, + ): + """If the inner to_thread task ends up cancelled, the drain loop must + exit instead of spinning forever on CancelledError.""" + from vllm_mlx.engine.base import run_blocking_startup_work + + def work_that_will_be_cancelled(): + raise asyncio.CancelledError() + + task = asyncio.create_task( + run_blocking_startup_work(work_that_will_be_cancelled) + ) + await asyncio.sleep(0) + task.cancel() + + # Must complete promptly — a livelock would hang here + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(task, timeout=2.0) + + @pytest.mark.anyio + async def test_blocking_cache_io_does_not_livelock_on_cancelled_inner_task(self): + """If the inner to_thread task ends cancelled, the drain loop must exit.""" + import vllm_mlx.server as srv + + class FakeEngine: + pass + + def io_that_will_be_cancelled(_engine): + raise asyncio.CancelledError() + + task = asyncio.create_task( + srv._run_blocking_engine_cache_io(io_that_will_be_cancelled, FakeEngine()) + ) + await asyncio.sleep(0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(task, timeout=2.0) + + @pytest.mark.anyio + async def test_prepare_engine_start_does_not_livelock_on_cancelled_inner_task(self): + """If prepare_for_start's to_thread task ends cancelled, the residency + drain loop must exit instead of spinning.""" + from vllm_mlx.lifecycle import ModelSpec, ResidencyManager + + class CancellingPrepEngine: + async def start(self): + pass + + async def stop(self): + pass + + def prepare_for_start(self): + raise asyncio.CancelledError() + + async def factory(spec): + return CancellingPrepEngine() + + manager = ResidencyManager(factory, auto_unload_idle_seconds=0) + manager.register_model(ModelSpec(model_key="m", model_name="test")) + + load_task = asyncio.create_task(manager.ensure_loaded("m")) + await asyncio.sleep(0) + load_task.cancel() + + # Must complete promptly — a livelock would cause wait_for to raise + # TimeoutError, which we want to surface as a hard failure. 
+ try: + await asyncio.wait_for(load_task, timeout=2.0) + except asyncio.CancelledError: + pass # expected: the load was cancelled + except asyncio.TimeoutError: + pytest.fail("Drain loop livelocked — load_task did not complete within 2s") + + # Clean up + with suppress(Exception): + await manager.shutdown() + + @pytest.mark.anyio + @pytest.mark.parametrize("start_phase", ["llm", "mllm"]) + async def test_batched_start_phase_cancellation_preserves_cancelled_error_when_stop_fails( + self, monkeypatch, caplog, start_phase + ): + """Batched eager startup should keep cancellation primary if teardown fails.""" + import vllm_mlx.engine.batched as engine_mod + + engine = engine_mod.BatchedEngine("fake-model") + start_entered = asyncio.Event() + stop_calls = {"count": 0} + + async def failing_stop(): + stop_calls["count"] += 1 + raise RuntimeError("stop boom") + + async def cancellable_start_phase(): + start_entered.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + raise + + monkeypatch.setattr(engine, "stop", failing_stop) + + if start_phase == "llm": + monkeypatch.setattr(engine, "_is_mllm", False, raising=False) + monkeypatch.setattr(engine, "_model", object(), raising=False) + monkeypatch.setattr(engine, "_tokenizer", object(), raising=False) + monkeypatch.setattr( + engine, "_start_llm", cancellable_start_phase, raising=False + ) + else: + monkeypatch.setattr(engine, "_is_mllm", True, raising=False) + monkeypatch.setattr(engine, "_model", object(), raising=False) + monkeypatch.setattr(engine, "_processor", object(), raising=False) + monkeypatch.setattr(engine, "_mllm_instance", object(), raising=False) + monkeypatch.setattr( + engine, + "_start_mllm", + cancellable_start_phase, + raising=False, + ) + + start_task = asyncio.create_task(engine.start()) + try: + await asyncio.wait_for(start_entered.wait(), timeout=1.0) + + caplog.set_level("ERROR") + caplog.clear() + start_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await start_task + + assert stop_calls["count"] == 1 + assert any( + record.levelname == "ERROR" + and "startup cleanup failed while preserving cancellation" + in record.getMessage() + for record in caplog.records + ) + finally: + if not start_task.done(): + start_task.cancel() + with suppress(asyncio.CancelledError): + await start_task + + @pytest.mark.anyio + async def test_cleanup_failure_does_not_orphan_live_eager_engine(self, monkeypatch): + """A failed eager-engine stop should keep the live engine guarded against replacement.""" + import vllm_mlx.server as srv + + class LiveEngine: + _loaded = True + + async def stop(self): + raise RuntimeError("stop boom") + + live_engine = LiveEngine() + + monkeypatch.setattr(srv, "_engine", live_engine, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + lifespan = srv.lifespan(srv.app) + try: + await lifespan.__anext__() + + with pytest.raises(RuntimeError, match="stop boom"): + await lifespan.__anext__() + + assert srv._engine is live_engine + with pytest.raises(RuntimeError, match="existing engine"): + srv.load_model("new-model") + finally: + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + async def 
test_cleanup_failure_before_residency_shutdown_keeps_live_manager_guarded( + self, monkeypatch + ): + """A pre-shutdown cleanup failure should not orphan a live residency manager.""" + import vllm_mlx.server as srv + + calls = {"shutdown": 0} + + class LiveEngine: + preserve_native_tool_format = False + + live_engine = LiveEngine() + + async def ensure_loaded(model_key): + return live_engine + + async def shutdown(): + calls["shutdown"] += 1 + raise AssertionError("shutdown should not be reached") + + class FakeMCPManager: + async def stop(self): + raise RuntimeError("mcp stop boom") + + live_manager = SimpleNamespace( + ensure_loaded=ensure_loaded, + get_engine=lambda model_key: live_engine, + get_status=lambda model_key: { + "state": "loaded", + "active_requests": 0, + }, + shutdown=shutdown, + ) + + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", live_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_mcp_manager", FakeMCPManager(), raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setattr(srv, "_lazy_load_model", False, raising=False) + + lifespan = srv.lifespan(srv.app) + try: + await lifespan.__anext__() + + with pytest.raises(RuntimeError, match="mcp stop boom"): + await lifespan.__anext__() + + assert calls["shutdown"] == 0 + assert srv._residency_manager is live_manager + with pytest.raises(RuntimeError, match="existing residency manager"): + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60, + ) + finally: + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + with suppress(Exception): + await lifespan.aclose() + + @pytest.mark.anyio + async def test_idle_unload_loop_survives_one_unload_failure(self, monkeypatch): + """One unload failure should not kill the background lifecycle loop.""" + import vllm_mlx.server as srv + + retried = asyncio.Event() + calls = {"count": 0} + + class FakeManager: + async def unload_if_idle(self, model_key): + calls["count"] += 1 + if calls["count"] == 1: + raise RuntimeError("boom") + retried.set() + return False + + monkeypatch.setattr(srv, "_residency_manager", FakeManager(), raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_auto_unload_idle_seconds", 1.0, raising=False) + monkeypatch.setattr(srv, "_sync_engine_from_residency", lambda: None) + + task = asyncio.create_task(srv._lifecycle_loop()) + try: + await asyncio.wait_for(retried.wait(), timeout=1.5) + finally: + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert calls["count"] >= 2 + + @pytest.mark.anyio + async def test_idle_unload_loop_honors_subsecond_timeout_granularity( + self, monkeypatch + ): + """Sub-second unload timeouts should not be rounded up to one-second polling.""" + import vllm_mlx.server as srv + + sleeps = [] + + class FakeManager: + async def unload_if_idle(self, model_key): + return False + + async def fake_sleep(delay): + sleeps.append(delay) + raise asyncio.CancelledError() + + monkeypatch.setattr(srv, "_residency_manager", FakeManager(), raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, 
"_auto_unload_idle_seconds", 0.5, raising=False) + monkeypatch.setattr(srv, "_sync_engine_from_residency", lambda: None) + monkeypatch.setattr(srv.asyncio, "sleep", fake_sleep) + + with pytest.raises(asyncio.CancelledError): + await srv._lifecycle_loop() + + assert sleeps == [0.25] + + @pytest.mark.anyio + async def test_cache_restore_hook_does_not_block_event_loop(self, monkeypatch): + """Cold-load cache restore should not freeze unrelated loop work.""" + import threading + + import vllm_mlx.server as srv + + callback_fired = threading.Event() + callback_seen_during_hook = {"value": False} + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + return None + + def load_cache_from_disk(self, path): + time.sleep(0.2) + callback_seen_during_hook["value"] = callback_fired.is_set() + return 1 + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + asyncio.get_running_loop().call_later(0.01, callback_fired.set) + acquire_task = asyncio.create_task(srv._acquire_default_engine()) + + await acquire_task + assert callback_seen_during_hook["value"] is True + await srv._release_default_engine() + + @pytest.mark.anyio + async def test_cache_persist_hook_does_not_block_event_loop(self, monkeypatch): + """Idle-unload cache persistence should not freeze unrelated loop work.""" + import threading + + import vllm_mlx.server as srv + + now = {"value": 1000.0} + callback_fired = threading.Event() + callback_seen_during_hook = {"value": False} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + def load_cache_from_disk(self, path): + return 0 + + def save_cache_to_disk(self, path): + time.sleep(0.2) + callback_seen_during_hook["value"] = callback_fired.is_set() + return True + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + manager = srv._residency_manager + assert manager is not None + monkeypatch.setattr(manager, "_time_fn", lambda: now["value"]) + + await srv._acquire_default_engine() + await srv._release_default_engine() + + now["value"] += 120.0 + asyncio.get_running_loop().call_later(0.01, callback_fired.set) + unload_task = asyncio.create_task(manager.unload_if_idle("default")) + + unloaded = await unload_task + assert unloaded is True + assert callback_seen_during_hook["value"] is True + + 
@pytest.mark.anyio + async def test_idle_unload_persists_and_restores_prefix_cache(self, monkeypatch): + """Server-managed idle unload should save and reload prefix cache state.""" + import vllm_mlx.server as srv + + load_calls = {"count": 0} + save_calls = {"count": 0} + now = {"value": 1000.0} + + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + def load_cache_from_disk(self, path): + return 1 + + def save_cache_to_disk(self, path): + return True + + async def fake_engine_factory(spec): + return FakeEngine() + + def fake_load_prefix_cache_from_disk(engine=None): + load_calls["count"] += 1 + + def fake_save_prefix_cache_to_disk(engine=None): + save_calls["count"] += 1 + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr( + srv, "_load_prefix_cache_from_disk", fake_load_prefix_cache_from_disk + ) + monkeypatch.setattr( + srv, "_save_prefix_cache_to_disk", fake_save_prefix_cache_to_disk + ) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60, + ) + + manager = srv._residency_manager + assert manager is not None + monkeypatch.setattr(manager, "_time_fn", lambda: now["value"]) + + await manager.acquire("default") + await manager.release("default") + now["value"] += 120.0 + await manager.unload_if_idle("default") + await manager.acquire("default") + + assert load_calls["count"] == 2 + assert save_calls["count"] == 1 + + await manager.release("default") + + @pytest.mark.anyio + async def test_lifespan_does_not_double_apply_cache_hooks_in_lifecycle_mode( + self, monkeypatch + ): + """Lifecycle startup/shutdown should not duplicate cache load/save hooks.""" + import vllm_mlx.server as srv + + load_calls = {"count": 0} + save_calls = {"count": 0} + + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + def load_cache_from_disk(self, path): + load_calls["count"] += 1 + return 1 + + def save_cache_to_disk(self, path): + save_calls["count"] += 1 + return True + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_mcp_manager", None) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + assert load_calls["count"] == 1 + + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + assert save_calls["count"] == 1 + + def test_residency_reload_invalidates_cached_tool_parser(self, monkeypatch): + """Tool parser instances should be rebuilt after an unload/reload engine swap.""" + import vllm_mlx.server as srv + + parser_tokenizers = [] + + class FakeParser: + def __init__(self, tokenizer=None): + parser_tokenizers.append(tokenizer) + + @classmethod + def supports_native_format(cls): + return False + + def reset(self): + return None + + def extract_tool_calls(self, output_text, request_dict=None): + return SimpleNamespace( + tools_called=False, + tool_calls=[], + content=output_text, + ) + + class FakeEngine: + def __init__(self, tokenizer): + self._tokenizer_value = tokenizer + self.preserve_native_tool_format = False + + @property + def 
tokenizer(self): + return self._tokenizer_value + + engine_state = {"engine": FakeEngine("tok-1")} + fake_manager = SimpleNamespace( + get_engine=lambda model_key: engine_state["engine"] + ) + + monkeypatch.setattr(srv, "_enable_auto_tool_choice", True, raising=False) + monkeypatch.setattr(srv, "_tool_call_parser", "fake", raising=False) + monkeypatch.setattr(srv, "_tool_parser_instance", None, raising=False) + monkeypatch.setattr(srv, "_engine", engine_state["engine"], raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr( + srv.ToolParserManager, + "get_tool_parser", + lambda name: FakeParser, + ) + + srv._parse_tool_calls_with_parser("") + assert parser_tokenizers == ["tok-1"] + + engine_state["engine"] = None + srv._sync_engine_from_residency() + engine_state["engine"] = FakeEngine("tok-2") + srv._sync_engine_from_residency() + + srv._parse_tool_calls_with_parser("") + assert parser_tokenizers == ["tok-1", "tok-2"] + + @pytest.mark.anyio + async def test_acquire_path_rebuilds_tool_parser_after_resident_swap( + self, monkeypatch + ): + """The first request path should rebuild parser state when acquire returns a new engine.""" + import vllm_mlx.server as srv + + parser_tokenizers = [] + + class FakeParser: + def __init__(self, tokenizer=None): + parser_tokenizers.append(tokenizer) + + @classmethod + def supports_native_format(cls): + return False + + def reset(self): + return None + + def extract_tool_calls(self, output_text, request_dict=None): + return SimpleNamespace( + tools_called=False, + tool_calls=[], + content=output_text, + ) + + class FakeEngine: + def __init__(self, tokenizer): + self._tokenizer_value = tokenizer + self.preserve_native_tool_format = False + + @property + def tokenizer(self): + return self._tokenizer_value + + current = {"engine": FakeEngine("tok-1")} + + async def fake_acquire(model_key): + return current["engine"] + + async def fake_release(model_key): + return None + + fake_manager = SimpleNamespace( + acquire=fake_acquire, + release=fake_release, + get_engine=lambda model_key: current["engine"], + ) + + monkeypatch.setattr(srv, "_enable_auto_tool_choice", True, raising=False) + monkeypatch.setattr(srv, "_tool_call_parser", "fake", raising=False) + monkeypatch.setattr(srv, "_tool_parser_instance", None, raising=False) + monkeypatch.setattr(srv, "_engine", current["engine"], raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr( + srv.ToolParserManager, + "get_tool_parser", + lambda name: FakeParser, + ) + + srv._parse_tool_calls_with_parser("") + assert parser_tokenizers == ["tok-1"] + + current["engine"] = FakeEngine("tok-2") + await srv._acquire_default_engine() + srv._parse_tool_calls_with_parser("") + + assert parser_tokenizers == ["tok-1", "tok-2"] + + await srv._release_default_engine() + + @pytest.mark.anyio + async def test_shutdown_clears_stopped_eager_engine_for_inprocess_reload( + self, monkeypatch + ): + """A stopped eager engine should not block the next in-process load_model().""" + import vllm_mlx.engine.simple as simple_mod + import vllm_mlx.server as srv + + class OldEngine: + _loaded = True + + def __init__(self): + self.stopped = False + + async def stop(self): + self.stopped = True + + old_engine = OldEngine() + + class NewEngine: + def __init__(self): + self.is_mllm = 
False + self.preserve_native_tool_format = False + + async def start(self): + return None + + new_engine = NewEngine() + + monkeypatch.setattr(srv, "_engine", old_engine, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr( + simple_mod, + "SimpleEngine", + lambda model_name, **kwargs: new_engine, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + assert old_engine.stopped is True + + await asyncio.to_thread( + srv.load_model, + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.0, + ) + + assert srv._engine is new_engine + + def test_load_model_closes_temporary_loop_after_eager_simple_start( + self, monkeypatch + ): + """In-process eager simple startup should close its temporary event loop.""" + import vllm_mlx.engine.simple as simple_mod + import vllm_mlx.server as srv + + created_loop = {"loop": None} + observed_loop = {"loop": None} + real_new_event_loop = asyncio.new_event_loop + previous_loop = real_new_event_loop() + + class FakeEngine: + is_mllm = False + preserve_native_tool_format = False + + async def start(self): + await asyncio.to_thread(lambda: None) + + async def stop(self): + return None + + def tracked_new_event_loop(): + loop = real_new_event_loop() + created_loop["loop"] = loop + return loop + + monkeypatch.setattr(asyncio, "new_event_loop", tracked_new_event_loop) + monkeypatch.setattr( + simple_mod, + "SimpleEngine", + lambda model_name, **kwargs: FakeEngine(), + ) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + try: + asyncio.set_event_loop(previous_loop) + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.0, + ) + + assert created_loop["loop"] is not None + assert created_loop["loop"].is_closed() is True + try: + observed_loop["loop"] = asyncio.get_event_loop() + except RuntimeError: + observed_loop["loop"] = None + if observed_loop["loop"] is not None: + assert observed_loop["loop"] is not created_loop["loop"] + assert observed_loop["loop"].is_closed() is False + finally: + asyncio.set_event_loop(None) + if ( + observed_loop["loop"] is not None + and observed_loop["loop"] is not previous_loop + and observed_loop["loop"] is not created_loop["loop"] + and not observed_loop["loop"].is_closed() + ): + observed_loop["loop"].close() + if not previous_loop.is_closed(): + previous_loop.close() + if ( + created_loop["loop"] is not None + and not created_loop["loop"].is_closed() + ): + created_loop["loop"].close() + monkeypatch.setattr(srv, "_engine", None, raising=False) + + @pytest.mark.anyio + async def test_load_model_rejects_reconfiguration_after_lifespan_start( + self, monkeypatch + ): + """Post-start reconfiguration should require a server restart and be a no-op.""" + import vllm_mlx.server as srv + + async def fake_engine_factory(spec): + raise AssertionError("lazy startup should not load during lifespan entry") + + 
monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.0, + lazy_load_model=True, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + original_engine = srv._engine + original_manager = srv._residency_manager + original_model_name = srv._model_name + original_default_model_key = srv._default_model_key + original_auto_unload = srv._auto_unload_idle_seconds + original_lazy_load = srv._lazy_load_model + original_parser = srv._tool_parser_instance + + try: + with pytest.raises(RuntimeError, match="restart the server"): + srv.load_model( + "mlx-community/Llama-3.2-3B-Instruct-4bit", + auto_unload_idle_seconds=0.0, + lazy_load_model=True, + ) + + assert srv._engine is original_engine + assert srv._residency_manager is original_manager + assert srv._model_name == original_model_name + assert srv._default_model_key == original_default_model_key + assert srv._auto_unload_idle_seconds == original_auto_unload + assert srv._lazy_load_model is original_lazy_load + assert srv._tool_parser_instance is original_parser + finally: + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + @pytest.mark.anyio + async def test_lazy_cold_acquire_does_not_block_event_loop(self, monkeypatch): + """Cold resident startup should not freeze unrelated event-loop work.""" + import vllm_mlx.server as srv + + class FakeEngine: + preserve_native_tool_format = False + + def prepare_for_start(self): + time.sleep(0.2) + + async def start(self): + return None + + async def stop(self): + return None + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + acquire_task = asyncio.create_task(srv._acquire_default_engine()) + + async def heartbeat(): + await asyncio.sleep(0) + return "heartbeat" + + await asyncio.sleep(0) + heartbeat_task = asyncio.create_task(heartbeat()) + assert await heartbeat_task == "heartbeat" + assert not acquire_task.done() + + await acquire_task + + await srv._release_default_engine() + + @pytest.mark.anyio + async def test_completion_timeout_covers_cold_resident_acquire(self, monkeypatch): + """Request timeout should include lazy-load engine acquisition.""" + from fastapi import HTTPException + + import vllm_mlx.server as srv + + generate_calls = {"count": 0} + load_gate = asyncio.Event() + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + await load_gate.wait() + + async def stop(self): + return None + + async def generate(self, **kwargs): + generate_calls["count"] 
+= 1 + return SimpleNamespace( + text="done", + finish_reason="stop", + completion_tokens=1, + prompt_tokens=1, + ) + + class FakeRawRequest: + async def json(self): + return { + "model": "mlx-community/Qwen3-0.6B-8bit", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "max_tokens": 16, + } + + async def is_disconnected(self): + return False + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + request = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + prompt="hi", + stream=False, + max_tokens=None, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + specprefill=None, + specprefill_keep_pct=None, + stop=None, + timeout=0.01, + ) + + request_task = asyncio.create_task( + srv.create_completion(request, FakeRawRequest()) + ) + try: + done, _ = await asyncio.wait({request_task}, timeout=0.2) + assert request_task in done + + with pytest.raises(HTTPException, match="Request timed out"): + await request_task + + assert generate_calls["count"] == 0 + + load_gate.set() + engine = await asyncio.wait_for(srv._acquire_default_engine(), timeout=1.0) + await srv._release_default_engine() + + status = srv._get_lifecycle_status() + assert engine is not None + assert status is not None + assert status["state"] == "loaded" + assert status["active_requests"] == 0 + finally: + load_gate.set() + if not request_task.done(): + with suppress(Exception): + await request_task + + @pytest.mark.anyio + async def test_chat_timeout_covers_cold_resident_acquire(self, monkeypatch): + """Chat timeout should include lazy-load engine acquisition.""" + from fastapi import HTTPException + + import vllm_mlx.server as srv + + chat_calls = {"count": 0} + load_gate = asyncio.Event() + + class FakeEngine: + is_mllm = False + preserve_native_tool_format = False + + async def start(self): + await load_gate.wait() + + async def stop(self): + return None + + async def chat(self, **kwargs): + chat_calls["count"] += 1 + return SimpleNamespace( + text="done", + finish_reason="stop", + completion_tokens=1, + prompt_tokens=1, + ) + + class FakeRawRequest: + async def json(self): + return { + "model": "mlx-community/Qwen3-0.6B-8bit", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "max_tokens": 16, + } + + async def is_disconnected(self): + return False + + def fake_extract(messages, preserve_native_format): + return ([{"role": "user", "content": "hi"}], [], []) + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "extract_multimodal_content", fake_extract) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + 
monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + request = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + messages=[SimpleNamespace(role="user", content="hi")], + stream=False, + max_tokens=None, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + response_format=None, + tools=None, + tool_choice=None, + enable_thinking=None, + video_fps=None, + video_max_frames=None, + specprefill=None, + specprefill_keep_pct=None, + chat_template_kwargs=None, + timeout=0.01, + ) + + request_task = asyncio.create_task( + srv.create_chat_completion(request, FakeRawRequest()) + ) + try: + done, _ = await asyncio.wait({request_task}, timeout=0.2) + assert request_task in done + + with pytest.raises(HTTPException, match="Request timed out"): + await request_task + + assert chat_calls["count"] == 0 + + load_gate.set() + engine = await asyncio.wait_for(srv._acquire_default_engine(), timeout=1.0) + await srv._release_default_engine() + + status = srv._get_lifecycle_status() + assert engine is not None + assert status is not None + assert status["state"] == "loaded" + assert status["active_requests"] == 0 + finally: + load_gate.set() + if not request_task.done(): + with suppress(Exception): + await request_task + + @pytest.mark.anyio + async def test_completion_disconnect_covers_cold_resident_acquire( + self, monkeypatch + ): + """Disconnect handling should abort a cold resident acquire before generation.""" + from fastapi.responses import Response + + import vllm_mlx.server as srv + + generate_calls = {"count": 0} + load_gate = asyncio.Event() + disconnect_polled = asyncio.Event() + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + await load_gate.wait() + + async def stop(self): + return None + + async def generate(self, **kwargs): + generate_calls["count"] += 1 + return SimpleNamespace( + text="done", + finish_reason="stop", + completion_tokens=1, + prompt_tokens=1, + ) + + class FakeRequest: + async def is_disconnected(self): + disconnect_polled.set() + return True + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + request = SimpleNamespace( + model="mlx-community/Qwen3-0.6B-8bit", + prompt="hi", + stream=False, + max_tokens=None, + temperature=None, + top_p=None, + top_k=None, + min_p=None, + presence_penalty=None, + repetition_penalty=None, + specprefill=None, + specprefill_keep_pct=None, + stop=None, + timeout=60.0, + ) + + request_task = asyncio.create_task( + srv.create_completion(request, FakeRequest()) + ) + try: + # Leave generous slack over the production 0.5s poll interval so + # this stays a behavior test rather than a scheduler-jitter race. 
+ await asyncio.wait_for(disconnect_polled.wait(), timeout=2.0) + + done, _ = await asyncio.wait({request_task}, timeout=1.0) + assert request_task in done + + response = await request_task + + assert isinstance(response, Response) + assert response.status_code == 499 + assert generate_calls["count"] == 0 + + load_gate.set() + engine = await asyncio.wait_for(srv._acquire_default_engine(), timeout=1.0) + await srv._release_default_engine() + + status = srv._get_lifecycle_status() + assert engine is not None + assert status is not None + assert status["state"] == "loaded" + assert status["active_requests"] == 0 + finally: + load_gate.set() + if not request_task.done(): + with suppress(Exception): + await request_task + + @pytest.mark.anyio + async def test_count_tokens_disconnect_covers_cold_resident_acquire( + self, monkeypatch + ): + """Token counting should unwind a solo cold load after a client disconnect.""" + from fastapi.responses import Response + + import vllm_mlx.server as srv + + created = 0 + encode_calls = {"count": 0} + disconnect_polled = asyncio.Event() + first_load_gate = asyncio.Event() + first_start_cancelled = asyncio.Event() + stopped_generations: list[int] = [] + + class FakeTokenizer: + def encode(self, text): + encode_calls["count"] += 1 + return list(range(len(text))) + + class FakeEngine: + def __init__(self, generation): + self.generation = generation + + preserve_native_tool_format = False + + async def start(self): + if self.generation != 1: + return None + try: + await first_load_gate.wait() + except asyncio.CancelledError: + first_start_cancelled.set() + raise + + async def stop(self): + stopped_generations.append(self.generation) + return None + + @property + def tokenizer(self): + return FakeTokenizer() + + class FakeRequest: + async def json(self): + return { + "system": "sys", + "messages": [{"role": "user", "content": "hi"}], + } + + async def is_disconnected(self): + disconnect_polled.set() + return True + + async def fake_engine_factory(spec): + nonlocal created + created += 1 + return FakeEngine(created) + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + request_task = asyncio.create_task(srv.count_anthropic_tokens(FakeRequest())) + try: + await asyncio.wait_for(disconnect_polled.wait(), timeout=2.0) + + done, _ = await asyncio.wait({request_task}, timeout=1.0) + assert request_task in done + + response = await request_task + + assert isinstance(response, Response) + assert response.status_code == 499 + assert encode_calls["count"] == 0 + + await asyncio.wait_for(first_start_cancelled.wait(), timeout=1.0) + + status = srv._get_lifecycle_status() + assert status is not None + assert status["state"] == "unloaded" + assert status["active_requests"] == 0 + assert stopped_generations == [1] + + engine = await asyncio.wait_for(srv._acquire_default_engine(), timeout=1.0) + await srv._release_default_engine() + + assert engine is not None + assert engine.generation == 2 + assert created == 2 + status = srv._get_lifecycle_status() + 
assert status["state"] == "loaded" + assert status["active_requests"] == 0 + finally: + first_load_gate.set() + if not request_task.done(): + with suppress(Exception): + await request_task + + @pytest.mark.anyio + async def test_count_tokens_does_not_refresh_idle_unload_activity( + self, monkeypatch + ): + """Budgeting-only traffic should not keep the resident hot indefinitely.""" + import vllm_mlx.server as srv + + now = {"value": 1000.0} + + class FakeTokenizer: + def encode(self, text): + return list(range(len(text))) + + class FakeEngine: + preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + return None + + @property + def tokenizer(self): + return FakeTokenizer() + + class FakeRequest: + async def json(self): + return { + "system": "sys", + "messages": [{"role": "user", "content": "hi"}], + } + + async def is_disconnected(self): + return False + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + manager = srv._residency_manager + assert manager is not None + monkeypatch.setattr(manager, "_time_fn", lambda: now["value"]) + + await srv._acquire_default_engine() + await srv._release_default_engine() + + now["value"] = 1059.0 + response = await srv.count_anthropic_tokens(FakeRequest()) + assert response == {"input_tokens": 5} + + now["value"] = 1061.0 + unloaded = await manager.unload_if_idle("default") + status = srv._get_lifecycle_status() + + assert unloaded is True + assert status is not None + assert status["state"] == "unloaded" + + @pytest.mark.anyio + async def test_count_tokens_validates_model_before_resident_acquire( + self, monkeypatch + ): + """count_tokens should reject wrong model names before cold resident acquire.""" + from fastapi import HTTPException + + import vllm_mlx.server as srv + + calls = {"acquire": 0} + + class FakeRequest: + async def json(self): + return { + "model": "wrong-model", + "system": "sys", + "messages": [{"role": "user", "content": "hi"}], + } + + async def is_disconnected(self): + return False + + async def fake_acquire_default_engine_for_request(*args, **kwargs): + calls["acquire"] += 1 + raise AssertionError("acquire should not run for wrong-model count_tokens") + + monkeypatch.setattr( + srv, + "_acquire_default_engine_for_request", + fake_acquire_default_engine_for_request, + ) + monkeypatch.setattr( + srv, + "_model_name", + "mlx-community/Qwen3-0.6B-8bit", + raising=False, + ) + + with pytest.raises(HTTPException, match="does not exist"): + await srv.count_anthropic_tokens(FakeRequest()) + + assert calls["acquire"] == 0 + + @pytest.mark.anyio + async def test_anthropic_messages_refresh_idle_unload_activity(self, monkeypatch): + """Successful Anthropic messages requests should count as residency activity.""" + import vllm_mlx.server as srv + + now = {"value": 1000.0} + + class FakeOutput: + text = "hello" + completion_tokens = 5 + prompt_tokens = 3 + finish_reason = "stop" + + class FakeEngine: + 
preserve_native_tool_format = False + + async def start(self): + return None + + async def stop(self): + return None + + async def chat(self, messages, **kwargs): + return FakeOutput() + + class FakeRawRequest: + async def json(self): + return { + "model": "mlx-community/Qwen3-0.6B-8bit", + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + "max_tokens": 16, + } + + async def is_disconnected(self): + return False + + async def fake_engine_factory(spec): + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + monkeypatch.setattr(srv, "_lifespan_active", False, raising=False) + monkeypatch.setattr( + srv, "_parse_tool_calls_with_parser", lambda text, request: (text, None) + ) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + manager = srv._residency_manager + assert manager is not None + monkeypatch.setattr(manager, "_time_fn", lambda: now["value"]) + + await srv._acquire_default_engine() + await srv._release_default_engine() + status = manager.get_status("default") + assert status["last_used_at"] == 1000.0 + + now["value"] = 1059.0 + response = await srv.create_anthropic_message(FakeRawRequest()) + assert response.status_code == 200 + + status = manager.get_status("default") + assert status["last_used_at"] == 1059.0 + + now["value"] = 1061.0 + unloaded = await manager.unload_if_idle("default") + status = srv._get_lifecycle_status() + + assert unloaded is False + assert status is not None + assert status["state"] == "loaded" + + @pytest.mark.anyio + async def test_wait_with_disconnect_treats_raced_task_cancellation_as_disconnect( + self, + ): + """A raced cancelled task should not leak CancelledError past disconnect handling.""" + import vllm_mlx.server as srv + + task_ref = {"task": None} + + class FakeRequest: + async def is_disconnected(self): + task = task_ref["task"] + assert task is not None + task.cancel() + await asyncio.sleep(0) + return True + + async def cancellable_work(): + await asyncio.sleep(3600) + + task = asyncio.create_task(cancellable_work()) + task_ref["task"] = task + + result = await srv._wait_with_disconnect( + task, + FakeRequest(), + timeout=1.0, + poll_interval=0.001, + ) + + assert result is None + assert task.cancelled() is True + + @pytest.mark.anyio + async def test_lazy_load_model_starts_unloaded_and_reports_unloaded_status( + self, monkeypatch + ): + """Lazy lifecycle mode should register an unloaded resident before first request.""" + import vllm_mlx.server as srv + + create_calls = {"count": 0} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + async def fake_engine_factory(spec): + create_calls["count"] += 1 + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + + srv.load_model( + 
"mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.0, + lazy_load_model=True, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + + assert create_calls["count"] == 0 + assert srv._engine is None + + health_payload = await srv.health() + status_payload = await srv.status() + + assert health_payload["model_loaded"] is False + assert health_payload["residency_state"] == "unloaded" + assert status_payload["status"] == "not_loaded" + assert status_payload["residency"]["state"] == "unloaded" + + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + @pytest.mark.anyio + async def test_lazy_load_first_acquire_triggers_initial_engine_load( + self, monkeypatch + ): + """The first real acquire after lazy startup should create and start the engine.""" + import vllm_mlx.server as srv + + create_calls = {"count": 0} + start_calls = {"count": 0} + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + + async def start(self): + start_calls["count"] += 1 + + async def stop(self): + return None + + def get_stats(self): + return {} + + async def fake_engine_factory(spec): + create_calls["count"] += 1 + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=0.0, + lazy_load_model=True, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + + assert create_calls["count"] == 0 + assert start_calls["count"] == 0 + + engine = await srv._acquire_default_engine() + status_payload = await srv.status() + + assert create_calls["count"] == 1 + assert start_calls["count"] == 1 + assert engine is srv._engine + assert status_payload["residency"]["state"] == "loaded" + assert status_payload["status"] == "stopped" + + await srv._release_default_engine() + + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + @pytest.mark.anyio + async def test_lazy_load_with_idle_unload_starts_unloaded_and_reports_unloaded_status( + self, monkeypatch + ): + """Lazy startup should stay cold even when idle auto-unload is also enabled.""" + import vllm_mlx.server as srv + + create_calls = {"count": 0} + + class FakeEngine: + async def start(self): + return None + + async def stop(self): + return None + + async def fake_engine_factory(spec): + create_calls["count"] += 1 + return FakeEngine() + + monkeypatch.setattr(srv, "_engine_factory", fake_engine_factory) + monkeypatch.setattr(srv, "_engine", None, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None, raising=False) + monkeypatch.setattr(srv, "_lifecycle_task", None, raising=False) + + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60.0, + lazy_load_model=True, + ) + + lifespan = srv.lifespan(srv.app) + await lifespan.__anext__() + + assert create_calls["count"] == 0 + assert srv._engine is None + + health_payload = await srv.health() + status_payload = await srv.status() + + assert health_payload["model_loaded"] is False + assert health_payload["residency_state"] 
== "unloaded" + assert health_payload["auto_unload_idle_seconds"] == 60.0 + assert status_payload["status"] == "not_loaded" + assert status_payload["residency"]["state"] == "unloaded" + assert status_payload["residency"]["auto_unload_idle_seconds"] == 60.0 + + with pytest.raises(StopAsyncIteration): + await lifespan.__anext__() + + def test_load_model_rejects_replacing_live_residency_manager(self, monkeypatch): + """load_model() should not overwrite a live residency manager in-process.""" + import vllm_mlx.server as srv + + live_engine = object() + live_manager = SimpleNamespace( + get_engine=lambda model_key: live_engine, + get_status=lambda model_key: { + "state": "loaded", + "active_requests": 0, + }, + ) + + monkeypatch.setattr(srv, "_residency_manager", live_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + + with pytest.raises(RuntimeError, match="existing residency manager"): + srv.load_model( + "mlx-community/Qwen3-0.6B-8bit", + auto_unload_idle_seconds=60, + ) + + assert srv._residency_manager is live_manager + + def test_load_model_rejection_leaves_server_globals_unchanged(self, monkeypatch): + """Rejected live-manager replacement should behave like a no-op.""" + import vllm_mlx.server as srv + + live_engine = object() + live_manager = SimpleNamespace( + get_engine=lambda model_key: live_engine, + get_status=lambda model_key: { + "state": "loaded", + "active_requests": 0, + }, + ) + + monkeypatch.setattr(srv, "_residency_manager", live_manager, raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_model_name", "old-model", raising=False) + monkeypatch.setattr(srv, "_default_max_tokens", 128, raising=False) + monkeypatch.setattr(srv, "_force_mllm_model", False, raising=False) + monkeypatch.setattr(srv, "_auto_unload_idle_seconds", 60.0, raising=False) + monkeypatch.setattr(srv, "_tool_parser_instance", object(), raising=False) + + original_parser = srv._tool_parser_instance + + with pytest.raises(RuntimeError, match="existing residency manager"): + srv.load_model( + "new-model", + max_tokens=999, + force_mllm=True, + auto_unload_idle_seconds=120, + ) + + assert srv._residency_manager is live_manager + assert srv._model_name == "old-model" + assert srv._default_max_tokens == 128 + assert srv._force_mllm_model is False + assert srv._auto_unload_idle_seconds == 60.0 + assert srv._tool_parser_instance is original_parser + + def test_load_model_rejects_live_legacy_engine_when_enabling_lifecycle( + self, monkeypatch + ): + """Enabling lifecycle should reject while a live eager engine is present.""" + import vllm_mlx.server as srv + + class LiveEngine: + def __init__(self): + self.stopped = False + + async def stop(self): + self.stopped = True + + live_engine = LiveEngine() + + monkeypatch.setattr(srv, "_engine", live_engine, raising=False) + monkeypatch.setattr(srv, "_residency_manager", None, raising=False) + monkeypatch.setattr(srv, "_default_model_key", None, raising=False) + monkeypatch.setattr(srv, "_model_name", "old-model", raising=False) + monkeypatch.setattr(srv, "_default_max_tokens", 128, raising=False) + monkeypatch.setattr(srv, "_force_mllm_model", False, raising=False) + monkeypatch.setattr(srv, "_auto_unload_idle_seconds", 0.0, raising=False) + monkeypatch.setattr(srv, "_tool_parser_instance", object(), raising=False) + + original_parser = srv._tool_parser_instance + + with pytest.raises(RuntimeError, match="existing engine"): + srv.load_model( + 
"new-model", + max_tokens=999, + force_mllm=True, + auto_unload_idle_seconds=120, + ) + assert live_engine.stopped is False + assert srv._engine is live_engine + assert srv._residency_manager is None + assert srv._default_model_key is None + assert srv._model_name == "old-model" + assert srv._default_max_tokens == 128 + assert srv._force_mllm_model is False + assert srv._auto_unload_idle_seconds == 0.0 + assert srv._tool_parser_instance is original_parser + + +class TestLifecycleLoopIdleEvent: + """Verify that the lifecycle loop uses an asyncio.Event for gating.""" + + @pytest.mark.anyio + async def test_lifecycle_loop_blocks_when_event_cleared(self): + """The loop should block on the Event, not busy-poll.""" + import vllm_mlx.server as srv + + iterations = 0 + + class FakeManager: + def __init__(self): + self.auto_unload_idle_seconds = 10 + + def get_engine(self, key): + return None + + def get_status(self, key): + return {"state": "unloaded", "active_requests": 0} + + async def unload_if_idle(self, key): + nonlocal iterations + iterations += 1 + return False + + monkeypatch_attrs = { + "_residency_manager": FakeManager(), + "_default_model_key": "default", + "_auto_unload_idle_seconds": 10.0, + } + originals = {} + for k, v in monkeypatch_attrs.items(): + originals[k] = getattr(srv, k, None) + setattr(srv, k, v) + + # Clear the event so the loop should block + idle_event = srv._get_idle_unload_event() + idle_event.clear() + + task = asyncio.create_task(srv._lifecycle_loop()) + try: + # Give it time — if it were polling at 0.1s it would iterate many times + await asyncio.sleep(0.3) + assert ( + iterations == 0 + ), f"Loop iterated {iterations} times while event was cleared" + + # Now set the event and let it run one iteration + idle_event.set() + await asyncio.sleep(0.1) + assert iterations >= 1 + finally: + task.cancel() + with suppress(asyncio.CancelledError): + await task + for k, v in originals.items(): + setattr(srv, k, v) + idle_event.set() + + +class TestPublicLifecycleStatusSanitization: + """Verify _public_lifecycle_status sanitization behavior.""" + + def test_none_error_stays_none(self): + import vllm_mlx.server as srv + + result = srv._public_lifecycle_status( + { + "state": "loaded", + "last_error": None, + } + ) + assert result["last_error"] is None + + def test_raw_error_replaced_with_category(self): + import vllm_mlx.server as srv + + result = srv._public_lifecycle_status( + { + "state": "failed", + "last_error": "OSError: /tmp/model not found", + } + ) + assert result["last_error"] == "model_load_failed" + + def test_returns_none_for_none_input(self): + import vllm_mlx.server as srv + + assert srv._public_lifecycle_status(None) is None + + @pytest.mark.anyio + async def test_health_surfaces_sanitized_error_on_failed_resident( + self, monkeypatch + ): + """Failed resident should show last_error='model_load_failed' on /health.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "state": "failed", + "active_requests": 0, + "last_used_at": None, + "loaded_at": None, + "last_error": "RuntimeError: out of memory", + "auto_unload_idle_seconds": 60, + } + ) + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr(srv, "_model_name", "test-model", raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + payload = await srv.health() 
+ + assert payload["status"] == "unhealthy" + assert payload["residency_state"] == "failed" + assert payload["last_error"] == "model_load_failed" + + @pytest.mark.anyio + async def test_health_omits_last_error_for_healthy_resident(self, monkeypatch): + """Healthy/loaded resident should not include last_error in health.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "state": "loaded", + "active_requests": 1, + "last_used_at": 1710200000.0, + "loaded_at": 1710199000.0, + "last_error": None, + "auto_unload_idle_seconds": 0, + } + ) + + class FakeEngine: + is_mllm = False + + def get_stats(self): + return {"engine_type": "simple"} + + monkeypatch.setattr(srv, "_engine", FakeEngine()) + monkeypatch.setattr(srv, "_model_name", "test-model", raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + payload = await srv.health() + + assert payload["status"] == "healthy" + assert "last_error" not in payload + + @pytest.mark.anyio + async def test_health_and_status_agree_on_empty_string_error(self, monkeypatch): + """An empty-string last_error should be treated the same by both + /health and /v1/status — both use ``is not None`` for consistency.""" + import vllm_mlx.server as srv + + fake_manager = SimpleNamespace( + get_status=lambda model_key: { + "model_key": model_key, + "model_name": "test-model", + "state": "failed", + "active_requests": 0, + "last_used_at": None, + "loaded_at": None, + "last_error": "", + "auto_unload_idle_seconds": 60, + } + ) + monkeypatch.setattr(srv, "_engine", None) + monkeypatch.setattr(srv, "_model_name", "test-model", raising=False) + monkeypatch.setattr(srv, "_default_model_key", "default", raising=False) + monkeypatch.setattr(srv, "_residency_manager", fake_manager, raising=False) + monkeypatch.setattr(srv, "_mcp_manager", None) + + health_payload = await srv.health() + status_payload = await srv.status() + + # Both should report the same sanitized error category + assert health_payload["last_error"] == "model_load_failed" + assert status_payload["residency"]["last_error"] == "model_load_failed" + + def test_public_lifecycle_status_empty_string_error(self): + """_public_lifecycle_status should treat '' the same as a real error.""" + import vllm_mlx.server as srv + + result = srv._public_lifecycle_status( + { + "state": "failed", + "last_error": "", + } + ) + assert result["last_error"] == "model_load_failed" + + +class TestResponseModelFieldUsesServedName: + """Verify response .model echoes _model_name (the served name), not + whatever the client sent in request.model. + + Each test monkeypatches _validate_model_name to a no-op so the request + can carry a distinct model string, proving the response field is sourced + from the server-side served name rather than the request echo-back. 
+ """ + + @pytest.mark.anyio + async def test_completion_response_uses_served_model_name(self, monkeypatch): + import vllm_mlx.server as srv + from vllm_mlx.engine.base import GenerationOutput + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + + async def generate(self, **kwargs): + return GenerationOutput( + text="hello", + completion_tokens=1, + prompt_tokens=1, + ) + + served_name = "my-custom-served-name" + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return FakeEngine() + + async def fake_release(*, count_activity=True): + pass + + monkeypatch.setattr(srv, "_validate_model_name", lambda _m: None) + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", served_name) + monkeypatch.setattr(srv, "_default_max_tokens", 32) + + class FakeRawRequest: + async def is_disconnected(self): + return False + + request = srv.CompletionRequest( + model="user-sent-model-name", + prompt="hi", + stream=False, + ) + + response = await srv.create_completion(request, FakeRawRequest()) + assert response.model == served_name + assert response.model != "user-sent-model-name" + + @pytest.mark.anyio + async def test_chat_completion_response_uses_served_model_name(self, monkeypatch): + import vllm_mlx.server as srv + from vllm_mlx.engine.base import GenerationOutput + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + tokenizer = None + + async def chat(self, **kwargs): + return GenerationOutput( + text="hello", + completion_tokens=1, + prompt_tokens=1, + ) + + served_name = "my-custom-served-name" + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return FakeEngine() + + async def fake_release(*, count_activity=True): + pass + + monkeypatch.setattr(srv, "_validate_model_name", lambda _m: None) + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", served_name) + monkeypatch.setattr(srv, "_default_max_tokens", 32) + monkeypatch.setattr(srv, "_enable_auto_tool_choice", False) + monkeypatch.setattr(srv, "_reasoning_parser", None) + + class FakeRawRequest: + async def is_disconnected(self): + return False + + request = srv.ChatCompletionRequest( + model="user-sent-model-name", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + response = await srv.create_chat_completion(request, FakeRawRequest()) + assert response.model == served_name + assert response.model != "user-sent-model-name" + + @pytest.mark.anyio + async def test_anthropic_response_uses_served_model_name(self, monkeypatch): + import json + + import vllm_mlx.server as srv + from vllm_mlx.engine.base import GenerationOutput + + class FakeEngine: + preserve_native_tool_format = False + is_mllm = False + tokenizer = None + + async def chat(self, **kwargs): + return GenerationOutput( + text="hello", + completion_tokens=1, + prompt_tokens=1, + ) + + served_name = "my-custom-served-name" + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return FakeEngine() + + async def fake_release(*, count_activity=True): + pass + + monkeypatch.setattr(srv, "_validate_model_name", lambda _m: None) + monkeypatch.setattr(srv, "_acquire_default_engine_for_request", 
fake_acquire) + monkeypatch.setattr(srv, "_release_default_engine", fake_release) + monkeypatch.setattr(srv, "_model_name", served_name) + monkeypatch.setattr(srv, "_default_max_tokens", 32) + monkeypatch.setattr(srv, "_enable_auto_tool_choice", False) + monkeypatch.setattr(srv, "_reasoning_parser", None) + + class FakeRawRequest: + async def json(self): + return { + "model": "user-sent-model-name", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 16, + "stream": False, + } + + async def is_disconnected(self): + return False + + response = await srv.create_anthropic_message(FakeRawRequest()) + body = json.loads(response.body.decode()) + assert body["model"] == served_name + assert body["model"] != "user-sent-model-name" diff --git a/tests/test_llm.py b/tests/test_llm.py index bd78ebc6a..701cfd9b2 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -3,6 +3,8 @@ import platform import sys +from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest @@ -51,6 +53,31 @@ def test_model_repr(): assert "not loaded" in repr_str +def test_model_stream_generate_passes_num_draft_tokens(): + """Native MTP path should forward configured draft depth to mlx_lm.""" + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel("test-model", mtp=True, mtp_num_draft_tokens=4) + model._loaded = True + model.model = object() + tokenizer = MagicMock() + tokenizer.encode.return_value = [1, 2, 3] + model.tokenizer = tokenizer + + captured_kwargs = {} + + def fake_stream_generate(_model, _tokenizer, **kwargs): + captured_kwargs.update(kwargs) + yield SimpleNamespace(text="Hello") + + with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate): + chunks = list(model.stream_generate("Hello", max_tokens=8)) + + assert chunks[-1].text == "Hello" + assert captured_kwargs["mtp"] is True + assert captured_kwargs["num_draft_tokens"] == 4 + + @pytest.mark.slow def test_model_load(small_model_name): """Test loading a model (slow test, downloads model).""" @@ -99,6 +126,58 @@ def test_model_stream_generate(small_model_name): assert any(chunk.finished for chunk in chunks) +@pytest.mark.slow +def test_model_stream_generate_with_prompt_cache(small_model_name): + """Test streaming generation with pre-populated prompt_cache.""" + pytest.importorskip("mlx_lm") + import mlx.core as mx + from mlx_lm.models.cache import make_prompt_cache + + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel(small_model_name) + model.load() + + # Pre-populate cache by running a prefill + cache = make_prompt_cache(model.model) + tokens = model.tokenizer.encode("Hello") + model.model(mx.array([tokens]), cache=cache) + mx.eval([c.state for c in cache]) + + # Generate from a single token with the pre-populated cache + prompt_token = mx.array([tokens[-1]]) + chunks = list( + model.stream_generate( + prompt=prompt_token, + max_tokens=10, + prompt_cache=cache, + ) + ) + + assert len(chunks) > 0 + assert any(chunk.finished for chunk in chunks) + # prompt_tokens should reflect the single-token prompt, not the full string + assert chunks[0].prompt_tokens == 1 + + +@pytest.mark.slow +def test_model_stream_generate_with_list_prompt(small_model_name): + """Test streaming generation with list[int] prompt.""" + pytest.importorskip("mlx_lm") + + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel(small_model_name) + model.load() + + token_ids = model.tokenizer.encode("Hello") + chunks = list(model.stream_generate(prompt=token_ids, 
max_tokens=10)) + + assert len(chunks) > 0 + assert any(chunk.finished for chunk in chunks) + assert chunks[0].prompt_tokens == len(token_ids) + + @pytest.mark.slow def test_model_chat(small_model_name): """Test chat interface.""" diff --git a/tests/test_mcp_security.py b/tests/test_mcp_security.py index eecc148f0..b68b5d00a 100644 --- a/tests/test_mcp_security.py +++ b/tests/test_mcp_security.py @@ -98,7 +98,16 @@ def test_path_traversal_blocked(self): with pytest.raises(MCPSecurityError) as exc_info: validator.validate_command("../../../bin/bash", "test-server") - assert "dangerous pattern" in str(exc_info.value) + assert "path traversal" in str(exc_info.value) + + def test_command_newline_blocked(self): + """Test that newline separators in commands are rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_command("npx\ncat /etc/passwd", "test-server") + + assert "newline characters" in str(exc_info.value) class TestArgumentValidation: @@ -142,6 +151,72 @@ def test_dollar_expansion_in_args_blocked(self): assert "dangerous pattern" in str(exc_info.value) + def test_python_inline_code_flag_blocked(self): + """Test that python -c is rejected even though python is whitelisted.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_command_args("python3", ["-c", "print('owned')"], "test") + + assert "inline Python execution" in str(exc_info.value) + + def test_node_eval_flag_blocked(self): + """Test that node --eval is rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_command_args( + "node", + ["--eval=console.log('owned')"], + "test", + ) + + assert "inline JavaScript evaluation" in str(exc_info.value) + + def test_npx_call_flag_blocked(self): + """Test that npx -c shell execution is rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_command_args("npx", ["-c", "echo owned"], "test") + + assert "shell command execution" in str(exc_info.value) + + def test_interpreter_normal_launch_args_still_pass(self): + """Test that standard MCP launches are unaffected.""" + validator = MCPCommandValidator(check_path_exists=False) + + validator.validate_command_args( + "npx", + ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "filesystem", + ) + validator.validate_command_args( + "uvx", + ["mcp-server-sqlite", "--db-path", "data.db"], + "sqlite", + ) + + def test_newline_in_args_blocked(self): + """Test that newline command separators in args are blocked.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_args(["safe", "line1\nline2"], "test-server") + + assert "newline characters" in str(exc_info.value) + + def test_url_encoded_path_traversal_in_args_blocked(self): + """Test that encoded traversal in args is rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_args( + ["--path", "%2e%2e/%2e%2e/etc/passwd"], "test-server" + ) + + assert "path traversal" in str(exc_info.value) + class TestEnvironmentValidation: """Tests for environment variable validation.""" @@ -189,6 +264,15 @@ def test_injection_in_env_value_blocked(self): assert "dangerous pattern" in 
str(exc_info.value) + def test_newline_in_env_value_blocked(self): + """Test that newlines in environment values are rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_env({"SAFE_VAR": "line1\rline2"}, "test-server") + + assert "newline characters" in str(exc_info.value) + class TestURLValidation: """Tests for SSE URL validation.""" @@ -236,6 +320,29 @@ def test_injection_in_url_blocked(self): assert "dangerous pattern" in str(exc_info.value) + def test_newline_in_url_blocked(self): + """Test that newlines in URLs are rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_url( + "https://example.com/sse\ncurl attacker", "test-server" + ) + + assert "newline characters" in str(exc_info.value) + + def test_url_encoded_path_traversal_in_url_blocked(self): + """Test that encoded traversal in URL paths is rejected.""" + validator = MCPCommandValidator(check_path_exists=False) + + with pytest.raises(MCPSecurityError) as exc_info: + validator.validate_url( + "https://example.com/%2e%2e/%2e%2e/private", + "test-server", + ) + + assert "path traversal" in str(exc_info.value) + class TestUnsafeMode: """Tests for unsafe mode (development only).""" @@ -322,6 +429,30 @@ def test_invalid_command_rejected(self): assert "not in the allowed commands whitelist" in str(exc_info.value) + def test_inline_python_execution_in_config_rejected(self): + """Test that interpreter eval forms are rejected in config.""" + with pytest.raises(ValueError) as exc_info: + MCPServerConfig( + name="inline-python", + transport=MCPTransport.STDIO, + command="python3", + args=["-c", "print('owned')"], + ) + + assert "inline Python execution" in str(exc_info.value) + + def test_node_eval_in_config_rejected(self): + """Test that node eval forms are rejected in config.""" + with pytest.raises(ValueError) as exc_info: + MCPServerConfig( + name="inline-node", + transport=MCPTransport.STDIO, + command="node", + args=["--eval=console.log('owned')"], + ) + + assert "inline JavaScript evaluation" in str(exc_info.value) + def test_command_injection_in_config_rejected(self): """Test that command injection in config is rejected.""" with pytest.raises(ValueError) as exc_info: @@ -333,6 +464,18 @@ def test_command_injection_in_config_rejected(self): assert "dangerous pattern" in str(exc_info.value) + def test_encoded_path_traversal_in_config_rejected(self): + """Test that encoded traversal in config args is rejected.""" + with pytest.raises(ValueError) as exc_info: + MCPServerConfig( + name="encoded-traversal", + transport=MCPTransport.STDIO, + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem", "%2e%2e/%2e%2e"], + ) + + assert "path traversal" in str(exc_info.value) + def test_valid_sse_config(self): """Test that valid SSE config passes validation.""" config = MCPServerConfig( @@ -342,6 +485,34 @@ def test_valid_sse_config(self): ) assert config.url == "https://api.example.com/mcp" + def test_config_parses_allowed_high_risk_tools(self): + """Root MCP config should accept explicit high-risk tool allowlists.""" + from vllm_mlx.mcp.config import validate_config + + config = validate_config( + { + "servers": {}, + "allowed_high_risk_tools": [ + "trusted__execute_command", + "run_shell", + ], + } + ) + + assert config.allowed_high_risk_tools == { + "trusted__execute_command", + "run_shell", + } + + def 
test_config_rejects_invalid_allowed_high_risk_tools(self): + """High-risk allowlists must be a list of non-empty strings.""" + from vllm_mlx.mcp.config import validate_config + + with pytest.raises(ValueError) as exc_info: + validate_config({"servers": {}, "allowed_high_risk_tools": ["ok", ""]}) + + assert "allowed_high_risk_tools" in str(exc_info.value) + def test_skip_security_validation_field_rejected(self): """Test that config-file security bypass is rejected explicitly.""" with pytest.raises(ValueError) as exc_info: @@ -802,27 +973,25 @@ def test_clear_audit_log(self): class TestToolSandboxHighRiskTools: """Tests for high-risk tool detection.""" - def test_high_risk_tool_warning(self, caplog): - """Test that high-risk tools trigger warning.""" - import logging - + def test_high_risk_tool_blocked_by_default(self): + """High-risk tools should be blocked unless explicitly allowlisted.""" sandbox = ToolSandbox() - with caplog.at_level(logging.WARNING): + with pytest.raises(MCPSecurityError) as exc_info: sandbox.validate_tool_execution( tool_name="execute_command", server_name="test", arguments={"cmd": "ls"}, ) - assert "High-risk tool detected" in caplog.text - assert "execute" in caplog.text + assert "High-risk tool 'execute_command' is blocked" in str(exc_info.value) + assert "allowed_high_risk_tools" in str(exc_info.value) - def test_high_risk_shell_tool(self, caplog): - """Test that shell tools trigger warning.""" + def test_high_risk_tool_allowed_by_short_name(self, caplog): + """Short-name allowlist entries should permit trusted high-risk tools.""" import logging - sandbox = ToolSandbox() + sandbox = ToolSandbox(allowed_high_risk_tools={"run_shell"}) with caplog.at_level(logging.WARNING): sandbox.validate_tool_execution( @@ -831,7 +1000,22 @@ def test_high_risk_shell_tool(self, caplog): arguments={}, ) - assert "High-risk tool detected" in caplog.text + assert "Allowing high-risk tool 'test__run_shell'" in caplog.text + + def test_high_risk_tool_allowed_by_full_name(self, caplog): + """Full-name allowlist entries should permit trusted high-risk tools.""" + import logging + + sandbox = ToolSandbox(allowed_high_risk_tools={"trusted__execute_command"}) + + with caplog.at_level(logging.WARNING): + sandbox.validate_tool_execution( + tool_name="execute_command", + server_name="trusted", + arguments={"cmd": "ls"}, + ) + + assert "Allowing high-risk tool 'trusted__execute_command'" in caplog.text class TestCustomBlockedPatterns: diff --git a/tests/test_memory_cache_mlx.py b/tests/test_memory_cache_mlx.py new file mode 100644 index 000000000..b39395485 --- /dev/null +++ b/tests/test_memory_cache_mlx.py @@ -0,0 +1,452 @@ +# SPDX-License-Identifier: Apache-2.0 +"""MLX-dependent regression tests for the LCP trim contamination fix (#384). + +These tests use real ``mlx_lm.models.cache.KVCache`` / ``RotatingKVCache`` +objects backed by ``mlx.core`` arrays. They run only on the apple-silicon +CI matrix (and locally on M-series hardware); the Linux ``test-matrix`` +job excludes this file because MLX has no Linux distribution. +""" + +from unittest.mock import MagicMock + + +class TestTrimCacheOffset: + """Tests for ``_trim_cache_offset``, focused on the LCP contamination fix. + + Regression: when the LCP fetch path trimmed a cache entry by shrinking + the offset while still sharing the underlying (oversized) key/value + arrays, downstream attention layers that read ``cache.state`` directly + (e.g. Gemma 4 KV-shared layers) could see stale tokens from the previous + owner of the entry. 
See issue #384. The fix slices the arrays down to + new_offset so no memory beyond the new boundary remains accessible. + """ + + def test_plain_kv_cache_array_sliced_to_new_offset(self): + """Plain-KVCache-like layer: after trim, keys.shape[-2] == new_offset.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layer = KVCache() + # Pretend a previous request wrote 500 tokens worth of data + layer.keys = mx.arange(1 * 4 * 500 * 8, dtype=mx.float32).reshape(1, 4, 500, 8) + layer.values = mx.arange(1 * 4 * 500 * 8, dtype=mx.float32).reshape( + 1, 4, 500, 8 + ) + layer.offset = 500 + + # New request shares only the first 60 tokens as prefix + trim_by = 500 - 60 + trimmed = _trim_cache_offset([layer], trim_by) + tc = trimmed[0] + + assert tc.offset == 60 + # The underlying array MUST be shrunk, not just the offset pointer. + # Otherwise Gemma 4's cache.state-reading layers would see positions + # 60..500 filled with the previous request's tokens. + assert tc.keys.shape[-2] == 60 + assert tc.values.shape[-2] == 60 + + def test_plain_kv_cache_no_stale_tokens_visible_via_state(self): + """A layer that reads the full cache.state must not see tokens past new_offset.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layer = KVCache() + # Positions 0..60: shared prefix (same for everyone). Positions 60..500: + # private content from a previous request that must NOT leak. + shared = mx.ones((1, 4, 60, 8), dtype=mx.float32) + private = mx.full((1, 4, 440, 8), 7.0, dtype=mx.float32) + layer.keys = mx.concatenate([shared, private], axis=2) + layer.values = layer.keys + layer.offset = 500 + + tc = _trim_cache_offset([layer], 500 - 60)[0] + + # cache.state is what KV-shared layers read directly. + keys_view, _ = tc.state + assert keys_view.shape[-2] == 60 + # No "7.0" tokens anywhere — private content was excluded. + assert float(mx.max(keys_view).item()) == 1.0 + + def test_plain_kv_cache_no_trim_preserves_array(self): + """If trim_by == 0 or offset already equals shape, array is untouched.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layer = KVCache() + layer.keys = mx.ones((1, 4, 100, 8), dtype=mx.float32) + layer.values = mx.ones((1, 4, 100, 8), dtype=mx.float32) + layer.offset = 100 + + tc = _trim_cache_offset([layer], 0)[0] + + assert tc.offset == 100 + assert tc.keys.shape[-2] == 100 + + def test_plain_kv_cache_trim_by_exceeds_offset_clamps_to_zero(self): + """trim_by larger than offset yields an empty-but-valid trimmed cache.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layer = KVCache() + layer.keys = mx.ones((1, 4, 80, 8), dtype=mx.float32) + layer.values = mx.ones((1, 4, 80, 8), dtype=mx.float32) + layer.offset = 80 + + tc = _trim_cache_offset([layer], 1000)[0] + + assert tc.offset == 0 + assert tc.keys.shape[-2] == 0 + assert tc.values.shape[-2] == 0 + + def test_plain_kv_cache_stored_entry_unaffected_after_trim(self): + """Calling _trim_cache_offset must not mutate the source layer in place. + + The stored prefix-cache entry is the source here; a later lookup for + a different request should get the same pristine data. 
+ """ + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layer = KVCache() + full = mx.arange(1 * 2 * 200 * 4, dtype=mx.float32).reshape(1, 2, 200, 4) + layer.keys = full + layer.values = full + layer.offset = 200 + original_shape = layer.keys.shape + + _trim_cache_offset([layer], 150) + + # Source entry keeps its full shape and offset. + assert layer.keys.shape == original_shape + assert layer.values.shape == original_shape + assert layer.offset == 200 + + def test_plain_kv_cache_in_place_write_does_not_corrupt_source(self): + """After trim, writing through the returned cache must not leak into + the stored entry. This is the direct semantics of the fix: the stored + prefix-cache entry has to survive concurrent use by other requests. + """ + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + # Source stored entry: positions 0..300 holding 5.0. + layer = KVCache() + layer.keys = mx.full((1, 2, 300, 4), 5.0, dtype=mx.float32) + layer.values = mx.full((1, 2, 300, 4), 5.0, dtype=mx.float32) + layer.offset = 300 + + # New request shares first 50 tokens. + tc = _trim_cache_offset([layer], 300 - 50)[0] + + # The trimmed cache now only has 50 tokens. Writing new tokens via + # update_and_fetch allocates a new array (because prev + N > current + # shape) and does not touch the source. + new_keys = mx.zeros((1, 2, 10, 4), dtype=mx.float32) + new_values = mx.zeros((1, 2, 10, 4), dtype=mx.float32) + tc.update_and_fetch(new_keys, new_values) + + # Source remains untouched (all 5.0 values preserved across full range). + assert layer.keys.shape[-2] == 300 + assert float(mx.min(layer.keys).item()) == 5.0 + assert float(mx.max(layer.keys).item()) == 5.0 + + def test_plain_kv_cache_multiple_layers_all_sliced(self): + """Caches with several KVCache layers: every layer gets sliced.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + layers = [] + for _ in range(5): + layer = KVCache() + layer.keys = mx.ones((1, 4, 200, 8), dtype=mx.float32) + layer.values = mx.ones((1, 4, 200, 8), dtype=mx.float32) + layer.offset = 200 + layers.append(layer) + + trimmed = _trim_cache_offset(layers, 150) + + assert len(trimmed) == 5 + for tc in trimmed: + assert tc.offset == 50 + assert tc.keys.shape[-2] == 50 + assert tc.values.shape[-2] == 50 + + def test_plain_kv_cache_slice_works_for_float16_and_bfloat16(self): + """Fix must be dtype-agnostic so quantized / mixed-precision KV caches + receive the same treatment as fp32. + """ + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import _trim_cache_offset + + for dtype in (mx.float16, mx.bfloat16): + layer = KVCache() + layer.keys = mx.ones((1, 2, 120, 4), dtype=dtype) + layer.values = mx.ones((1, 2, 120, 4), dtype=dtype) + layer.offset = 120 + + tc = _trim_cache_offset([layer], 80)[0] + + assert tc.offset == 40, f"dtype={dtype}" + assert tc.keys.shape[-2] == 40, f"dtype={dtype}" + assert tc.keys.dtype == dtype, f"dtype={dtype}" + + def test_plain_kv_cache_rotating_layers_unchanged_behavior(self): + """RotatingKVCache was already trimming correctly before this fix. + The plain-KVCache branch is the only one that changed; the rotating + branch is exercised here to catch regressions. 
+        """
+        import mlx.core as mx
+        from mlx_lm.models.cache import RotatingKVCache
+
+        from vllm_mlx.memory_cache import _trim_cache_offset
+
+        layer = RotatingKVCache(max_size=128, keep=0)
+        # Layer already rotated once: offset=200, buffer holds max_size entries.
+        layer.keys = mx.ones((1, 4, 128, 8), dtype=mx.float32)
+        layer.values = mx.ones((1, 4, 128, 8), dtype=mx.float32)
+        layer.offset = 200
+        layer._idx = 128
+
+        tc = _trim_cache_offset([layer], 100)[0]
+
+        # Offset dropped by trim_by, clamped at >= 0.
+        assert tc.offset == 100
+        # Rotating path materialises a buffer whose shape matches new_offset
+        # (padding with zeros if needed). It must not come back as None.
+        assert tc.keys is not None
+        assert tc.values is not None
+        # Dtype preserved through trim.
+        assert tc.keys.dtype == mx.float32
+        # Type-specific attrs preserved.
+        assert hasattr(tc, "max_size")
+        assert tc.max_size == 128
+
+    def test_fetch_returns_sliced_cache_on_lcp_match(self):
+        """End-to-end: MemoryAwarePrefixCache.fetch on a request that shares
+        only a prefix with a longer stored entry must return a cache whose
+        arrays are already sliced down. This exercises the full #384
+        regression scenario one level above the unit tests.
+        """
+        import mlx.core as mx
+        from mlx_lm.models.cache import KVCache
+
+        from vllm_mlx.memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
+
+        model = MagicMock()
+        cache = MemoryAwarePrefixCache(
+            model, MemoryCacheConfig(max_memory_mb=64, max_entries=10)
+        )
+
+        # Stored: tokens [1..120] with 120 positions of KV data, the first 60
+        # tokens being the shared prefix (all 1.0), the last 60 private (7.0).
+        stored_layer = KVCache()
+        shared = mx.ones((1, 2, 60, 4), dtype=mx.float32)
+        private = mx.full((1, 2, 60, 4), 7.0, dtype=mx.float32)
+        stored_layer.keys = mx.concatenate([shared, private], axis=2)
+        stored_layer.values = stored_layer.keys
+        stored_layer.offset = 120
+        cache.store(list(range(1, 121)), [stored_layer])
+
+        # New request: tokens [1..59] + [999, 1000, 1001] — first 59 tokens
+        # match, then diverge. LCP is 59.
+        new_tokens = list(range(1, 60)) + [999, 1000, 1001]
+        fetched, remaining = cache.fetch(new_tokens)
+
+        assert fetched is not None
+        tc = fetched[0]
+        # LCP of 59 (the divergent tokens are stripped).
+        assert tc.offset == 59
+        assert tc.keys.shape[-2] == 59
+        # Critical: the "7.0" private content from the stored entry must NOT
+        # be visible anywhere in the returned cache (this is what caused the
+        # cross-request contamination in #384).
+        assert float(mx.max(tc.keys).item()) == 1.0
+        assert remaining == [999, 1000, 1001]
+
+
+class TestDequantizeCacheSlice:
+    """Tests for _dequantize_cache slicing after dequantization.
+
+    When KV cache quantization is enabled (--kv-cache-quantization), the
+    prefix cache stores _QuantizedCacheWrapper layers. After LCP trim
+    reduces the offset, _dequantize_cache must slice the dequantized arrays
+    down to offset to prevent readers that bypass offset (e.g. Gemma 4's
+    KV-shared layers reading cache.state) from seeing stale tokens.
+
+    This is the quantized-cache counterpart of the plain-KVCache fix
+    tested in TestTrimCacheOffset above.
+    """
+
+    def test_dequantize_slices_to_offset(self):
+        """After trim + dequantize, keys/values shape[-2] == offset."""
+        import mlx.core as mx
+        from mlx_lm.models.cache import KVCache
+
+        from vllm_mlx.memory_cache import (
+            _QuantizedCacheWrapper,
+            _dequantize_cache,
+            _trim_cache_offset,
+        )
+
+        # Build a KVCache with 512 tokens, quantize it, then trim to 60.
+ layer = KVCache() + layer.keys = mx.ones((1, 4, 512, 64), dtype=mx.float32) + layer.values = mx.ones((1, 4, 512, 64), dtype=mx.float32) + layer.offset = 512 + mx.eval(layer.keys, layer.values) + + qw = _QuantizedCacheWrapper(layer, bits=8, group_size=64) + trimmed = _trim_cache_offset([qw], 512 - 60) + result = _dequantize_cache(trimmed) + + tc = result[0] + assert tc.offset == 60 + assert tc.keys.shape[-2] == 60 + assert tc.values.shape[-2] == 60 + + def test_dequantize_no_stale_tokens_via_state(self): + """Stale tokens from a previous request must not be visible via cache.state.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import ( + _QuantizedCacheWrapper, + _dequantize_cache, + _trim_cache_offset, + ) + + layer = KVCache() + # First 64 positions: shared prefix (1.0), next 448: private (7.0) + shared = mx.ones((1, 4, 64, 64), dtype=mx.float32) + private = mx.full((1, 4, 448, 64), 7.0, dtype=mx.float32) + layer.keys = mx.concatenate([shared, private], axis=2) + layer.values = mx.concatenate([shared, private], axis=2) + layer.offset = 512 + mx.eval(layer.keys, layer.values) + + qw = _QuantizedCacheWrapper(layer, bits=8, group_size=64) + trimmed = _trim_cache_offset([qw], 512 - 64) + result = _dequantize_cache(trimmed) + + tc = result[0] + keys_view, _ = tc.state + assert keys_view.shape[-2] == 64 + # Dequantized values are approximate (quantization error), but should + # be close to 1.0 (the shared prefix), never near 7.0 (the private data). + assert float(mx.max(keys_view).item()) < 2.0 + + def test_dequantize_no_trim_preserves_full_array(self): + """When offset == shape[-2], no slicing occurs.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import ( + _QuantizedCacheWrapper, + _dequantize_cache, + ) + + layer = KVCache() + layer.keys = mx.ones((1, 4, 128, 64), dtype=mx.float32) + layer.values = mx.ones((1, 4, 128, 64), dtype=mx.float32) + layer.offset = 128 + mx.eval(layer.keys, layer.values) + + qw = _QuantizedCacheWrapper(layer, bits=8, group_size=64) + result = _dequantize_cache([qw]) + + tc = result[0] + assert tc.offset == 128 + assert tc.keys.shape[-2] == 128 + assert tc.values.shape[-2] == 128 + + def test_dequantize_source_unaffected(self): + """Dequantizing must not mutate the stored quantized wrapper.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import ( + _QuantizedCacheWrapper, + _dequantize_cache, + _trim_cache_offset, + ) + + layer = KVCache() + layer.keys = mx.ones((1, 4, 256, 64), dtype=mx.float32) + layer.values = mx.ones((1, 4, 256, 64), dtype=mx.float32) + layer.offset = 256 + mx.eval(layer.keys, layer.values) + + qw = _QuantizedCacheWrapper(layer, bits=8, group_size=64) + original_offset = qw.offset + original_keys_shape = qw.keys[0].shape # quantized data tuple + + trimmed = _trim_cache_offset([qw], 192) + _dequantize_cache(trimmed) + + # Source wrapper unchanged + assert qw.offset == original_offset + assert qw.keys[0].shape == original_keys_shape + + def test_dequantize_end_to_end_fetch_with_quantization(self): + """End-to-end: store with kv_quantize=True, fetch with LCP, verify no stale data.""" + import mlx.core as mx + from mlx_lm.models.cache import KVCache + + from vllm_mlx.memory_cache import ( + MemoryAwarePrefixCache, + MemoryCacheConfig, + ) + + model = MagicMock() + pc = MemoryAwarePrefixCache( + model, + MemoryCacheConfig( + max_memory_mb=64, + max_entries=10, + kv_quantize=True, + kv_bits=8, 
+ kv_group_size=64, + kv_min_quantize_tokens=0, + ), + ) + + # Store a KVCache with 128 tokens — store() quantizes automatically. + layer = KVCache() + shared = mx.ones((1, 2, 64, 64), dtype=mx.float32) + private = mx.full((1, 2, 64, 64), 7.0, dtype=mx.float32) + layer.keys = mx.concatenate([shared, private], axis=2) + layer.values = mx.concatenate([shared, private], axis=2) + layer.offset = 128 + mx.eval(layer.keys, layer.values) + + pc.store(list(range(1, 129)), [layer]) + + # Fetch with partial match (first 60 tokens match, then diverge). + # fetch() dequantizes automatically when kv_quantize=True. + new_tokens = list(range(1, 61)) + [999, 1000] + fetched, remaining = pc.fetch(new_tokens) + + assert fetched is not None + tc = fetched[0] + assert tc.offset == 60 + assert tc.keys.shape[-2] == 60 + # No private data (7.0) visible — only shared prefix (~1.0 with quantization noise) + assert float(mx.max(tc.keys).item()) < 2.0 + assert remaining == [999, 1000] diff --git a/tests/test_mllm.py b/tests/test_mllm.py index ea9ee6593..049b840a3 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for MLX Multimodal Language Model (MLLM) wrapper.""" +import base64 import platform import sys from pathlib import Path @@ -123,6 +124,67 @@ def test_is_url(self): assert not is_url("/path/to/file.jpg") assert not is_url("data:image/png;base64,AAAA") + @pytest.mark.parametrize( + "url", + [ + "file:///etc/passwd", + "http://127.0.0.1/private.jpg", + "http://169.254.169.254/latest/meta-data/", + "http://localhost/private.jpg", + "http://[::1]/private.jpg", + ], + ) + def test_validate_url_safety_rejects_unsafe_targets(self, url): + """Test that local and private targets are rejected.""" + from vllm_mlx.models.mllm import UnsafeRemoteURLError, _validate_url_safety + + with pytest.raises(UnsafeRemoteURLError): + _validate_url_safety(url) + + def test_validate_url_safety_allows_public_ip(self): + """Test that public literal IPs remain allowed.""" + from vllm_mlx.models.mllm import _validate_url_safety + + _validate_url_safety("https://8.8.8.8/image.jpg") + + def test_request_with_safe_redirects_blocks_unsafe_redirect(self, monkeypatch): + """Test that redirect hops are validated before a second request.""" + from vllm_mlx.models import mllm + + calls = [] + + class FakeResponse: + def __init__(self, url, status_code, headers): + self.url = url + self.status_code = status_code + self.headers = headers + self.is_redirect = status_code in {301, 302, 303, 307, 308} + self.is_permanent_redirect = status_code in {301, 308} + + def close(self): + return None + + def fake_request(method, url, **kwargs): + calls.append((method, url, kwargs["allow_redirects"])) + return FakeResponse( + url, + 302, + {"location": "http://127.0.0.1/internal.jpg"}, + ) + + monkeypatch.setattr(mllm.requests, "request", fake_request) + + with pytest.raises(mllm.UnsafeRemoteURLError): + mllm._request_with_safe_redirects( + "GET", + "https://8.8.8.8/image.jpg", + timeout=30, + headers={"User-Agent": "test"}, + stream=True, + ) + + assert calls == [("GET", "https://8.8.8.8/image.jpg", False)] + class TestVideoFrameExtraction: """Test video frame extraction functions.""" @@ -184,38 +246,50 @@ def test_save_frames_to_temp(self, test_video_path): class TestImageProcessing: """Test image processing functions.""" - def test_process_image_input_local_file(self, test_image_path): - """Test processing local image file.""" + def test_process_image_input_local_file_rejected(self, 
test_image_path): + """Test that local image paths are rejected.""" from vllm_mlx.models.mllm import process_image_input - result = process_image_input(test_image_path) - assert result == test_image_path + with pytest.raises(ValueError, match="Unsupported image input"): + process_image_input(test_image_path) - def test_process_image_input_dict_format(self, test_image_path): - """Test processing image in dict format.""" + def test_process_image_input_dict_format_base64(self): + """Test processing image in dict format with base64 payload.""" from vllm_mlx.models.mllm import process_image_input - # OpenAI format - result = process_image_input({"url": test_image_path}) + image_b64 = base64.b64encode(b"\x89PNG\r\n\x1a\n\x00\x00\x00\x0dIHDR").decode() + result = process_image_input({"url": f"data:image/png;base64,{image_b64}"}) assert Path(result).exists() + def test_download_image_blocks_unsafe_url_before_request(self, monkeypatch): + """Test that blocked image URLs fail before any request is made.""" + from vllm_mlx.models import mllm + + def fail_request(*args, **kwargs): + raise AssertionError("requests.request should not be called") + + monkeypatch.setattr(mllm.requests, "request", fail_request) + + with pytest.raises(mllm.UnsafeRemoteURLError): + mllm.download_image("http://169.254.169.254/latest/meta-data/") + class TestVideoProcessing: """Test video processing functions.""" - def test_process_video_input_local_file(self, test_video_path): - """Test processing local video file.""" + def test_process_video_input_local_file_rejected(self, test_video_path): + """Test that local video paths are rejected.""" from vllm_mlx.models.mllm import process_video_input - result = process_video_input(test_video_path) - assert result == test_video_path + with pytest.raises(ValueError, match="Unsupported video input"): + process_video_input(test_video_path) - def test_process_video_input_dict_format(self, test_video_path): - """Test processing video in dict format.""" + def test_process_video_input_dict_format_base64(self): + """Test processing video in dict format with base64 payload.""" from vllm_mlx.models.mllm import process_video_input - # OpenAI format - result = process_video_input({"url": test_video_path}) + video_b64 = base64.b64encode(b"\x00" * 100).decode() + result = process_video_input({"url": f"data:video/mp4;base64,{video_b64}"}) assert Path(result).exists() def test_process_video_input_empty_raises(self): @@ -228,6 +302,18 @@ def test_process_video_input_empty_raises(self): with pytest.raises(ValueError): process_video_input({}) + def test_download_video_blocks_unsafe_url_before_request(self, monkeypatch): + """Test that blocked video URLs fail before any request is made.""" + from vllm_mlx.models import mllm + + def fail_request(*args, **kwargs): + raise AssertionError("requests.request should not be called") + + monkeypatch.setattr(mllm.requests, "request", fail_request) + + with pytest.raises(mllm.UnsafeRemoteURLError): + mllm.download_video("http://127.0.0.1/internal.mp4") + # ============================================================================= # MLLM Model Tests diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py index 7cafb81f7..6a63e4b55 100644 --- a/tests/test_mllm_continuous_batching.py +++ b/tests/test_mllm_continuous_batching.py @@ -13,9 +13,11 @@ - Mixed text-only and multimodal requests """ +import asyncio import base64 import os import tempfile +from contextlib import nullcontext from unittest.mock import MagicMock import 
pytest @@ -35,6 +37,76 @@ TEST_IMAGE_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" +class TestMLLMPromptCacheEval: + def test_collects_kv_and_arrays_cache_tensors(self): + from vllm_mlx.mllm_batch_generator import _cache_eval_tensors + + kv_keys = object() + kv_values = object() + state_a = object() + state_b = object() + + class KVLikeCache: + keys = kv_keys + values = kv_values + + @property + def state(self): + raise AssertionError("KV cache state should not be read") + + class ArraysLikeCache: + state = [state_a, None, state_b] + + class EmptyKVLikeCache: + keys = None + values = None + + @property + def state(self): + raise AttributeError("empty KV cache has no state") + + assert _cache_eval_tensors( + [KVLikeCache(), ArraysLikeCache(), EmptyKVLikeCache()] + ) == [kv_keys, kv_values, state_a, state_b] + + def test_eval_prompt_cache_skips_empty_cache(self, monkeypatch): + from vllm_mlx.mllm_batch_generator import _eval_prompt_cache + + eval_mock = MagicMock() + monkeypatch.setattr(mx, "eval", eval_mock) + + class EmptyCache: + keys = None + values = None + state = [None] + + _eval_prompt_cache([EmptyCache()]) + + eval_mock.assert_not_called() + + def test_eval_prompt_cache_flattens_cache_tensors(self, monkeypatch): + from vllm_mlx.mllm_batch_generator import _eval_prompt_cache + + kv_keys = object() + kv_values = object() + state = object() + eval_mock = MagicMock() + monkeypatch.setattr(mx, "eval", eval_mock) + + class KVLikeCache: + keys = kv_keys + values = kv_values + + class ArraysLikeCache: + pass + + ArraysLikeCache.state = [state] + + _eval_prompt_cache([KVLikeCache(), ArraysLikeCache()]) + + eval_mock.assert_called_once_with(kv_keys, kv_values, state) + + def create_test_image(path: str, size: tuple = (32, 32)) -> str: """Create a test image file.""" try: @@ -234,6 +306,51 @@ def test_batch_filter(self): assert batch.uids == [1, 3] assert batch.request_ids == ["req-1", "req-3"] + def test_batch_extend_handles_empty_protocol_caches_without_keys(self): + """Caches with empty()/extend() but no .keys still need batch extension.""" + from vllm_mlx.mllm_batch_generator import MLLMBatch, MLLMBatchRequest + + class OpaqueCache: + def __init__(self): + self.extend_calls = 0 + self.extended_with = None + + def empty(self): + return False + + def extend(self, other): + self.extend_calls += 1 + self.extended_with = other + + primary_cache = OpaqueCache() + other_cache = OpaqueCache() + primary = MLLMBatch( + uids=[1], + request_ids=["req-1"], + y=mx.array([100]), + logprobs=[mx.array([0.1])], + max_tokens=[100], + num_tokens=[0], + cache=[primary_cache], + requests=[MLLMBatchRequest(uid=1, request_id="req-1", prompt="one")], + ) + other = MLLMBatch( + uids=[2], + request_ids=["req-2"], + y=mx.array([200]), + logprobs=[mx.array([0.2])], + max_tokens=[100], + num_tokens=[0], + cache=[other_cache], + requests=[MLLMBatchRequest(uid=2, request_id="req-2", prompt="two")], + ) + + primary.extend(other) + + assert primary.y.shape == (2,) + assert primary_cache.extend_calls == 1 + assert primary_cache.extended_with is other_cache + class TestMLLMBatchStats: """Tests for MLLMBatchStats.""" @@ -642,7 +759,528 @@ async def test_streaming(self, test_image_path): finally: await scheduler.stop() + async def test_stream_outputs_consumer_break_after_finished_does_not_abort(self): + """Breaking after a finished output is normal consumption, not orphaning.""" + from vllm_mlx.mllm_scheduler import MLLMScheduler + from vllm_mlx.request import 
RequestOutput + + scheduler = MLLMScheduler.__new__(MLLMScheduler) + scheduler.output_queues = {"req-1": asyncio.Queue()} + scheduler.abort_request = MagicMock(return_value=True) + + await scheduler.output_queues["req-1"].put( + RequestOutput( + request_id="req-1", + output_text="done", + finished=True, + finish_reason="stop", + ) + ) + + stream = MLLMScheduler.stream_outputs(scheduler, "req-1") + output = await stream.__anext__() + assert output.finished is True + await stream.aclose() + + scheduler.abort_request.assert_not_called() + assert "req-1" not in scheduler.output_queues + # Run tests if __name__ == "__main__": pytest.main([__file__, "-v"]) + + +class TestMLLMBatchGeneratorMTPGuards: + def test_process_prompts_applies_request_sampling_to_first_token(self, monkeypatch): + from vllm_mlx.mllm_batch_generator import ( + MLLMBatchGenerator, + MLLMBatchRequest, + MLLMBatchStats, + ) + + class FakeCache: + def merge(self, caches): + return self + + class RecordingProcessor: + def __init__(self): + self.calls = [] + + def __call__(self, tokens, logits): + self.calls.append(tokens.tolist()) + return logits + + processor = RecordingProcessor() + request_sampler = MagicMock(return_value=mx.array([3], dtype=mx.uint32)) + fallback_sampler = MagicMock(return_value=mx.array([1], dtype=mx.uint32)) + sampler_calls = [] + + def fake_make_sampler(**kwargs): + sampler_calls.append(kwargs) + return request_sampler + + monkeypatch.setattr(mx, "stream", lambda stream: nullcontext()) + monkeypatch.setattr( + "mlx_lm.models.cache.make_prompt_cache", lambda model: [FakeCache()] + ) + monkeypatch.setattr("mlx_lm.sample_utils.make_sampler", fake_make_sampler) + monkeypatch.setattr( + "mlx_lm.sample_utils.make_logits_processors", lambda **_: [] + ) + + generator = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + generator._stats = MLLMBatchStats() + generator._pending_error_responses = [] + generator._aborted_request_ids = set() + generator._prefill_progress = {} + generator.prefix_cache = None + generator.prefill_step_size = 512 + generator.language_model = object() + generator.model = MagicMock() + generator.sampler = fallback_sampler + generator._trim_rotating_caches = lambda cache: None + generator._preprocess_request = lambda req: None + generator._run_chunked_text_prefill = lambda req, cache: mx.array( + [[[0.0, 1.0, 2.0, 3.0]]] + ) + + request = MLLMBatchRequest( + uid=7, + request_id="req-7", + prompt="hello", + temperature=0.3, + top_p=0.8, + top_k=0, + min_p=0.0, + logits_processors=[processor], + ) + request.input_ids = mx.array([[42]], dtype=mx.uint32) + request.is_text_only = True + + batch = MLLMBatchGenerator._process_prompts(generator, [request]) + + assert batch.y.tolist() == [3] + assert batch.samplers == [request_sampler] + assert batch.logits_processors == [[processor]] + assert processor.calls == [[]] + assert request_sampler.call_count == 1 + fallback_sampler.assert_not_called() + assert sampler_calls == [{"temp": 0.3, "top_p": 0.8, "top_k": 0, "min_p": 0.0}] + + def test_next_passes_current_token_to_logits_processor_prefix(self): + from vllm_mlx.mllm_batch_generator import ( + MLLMBatch, + MLLMBatchGenerator, + MLLMBatchRequest, + MLLMBatchStats, + ) + + captured = {} + + def fake_step(input_tokens, cache, logits_processors, output_tokens, samplers): + captured["input_tokens"] = input_tokens.tolist() + captured["output_tokens"] = output_tokens + return mx.array([11]), [mx.array([0.2, 0.8])] + + generator = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + generator._stats = 
MLLMBatchStats() + generator.stop_tokens = set() + generator.unprocessed_requests = [] + generator._pending_error_responses = [] + generator._prefill_progress = {} + generator.prefix_cache = None + generator._maybe_store_prefix_cache = lambda batch, end_idx: None + generator._step = fake_step + + processor = MagicMock() + request = MLLMBatchRequest(uid=1, request_id="req-1", prompt="hello") + request.output_tokens = [5] + generator.active_batch = MLLMBatch( + uids=[1], + request_ids=["req-1"], + y=mx.array([7]), + logprobs=[mx.array([0.5, 0.5])], + max_tokens=[8], + num_tokens=[1], + cache=[], + requests=[request], + logits_processors=[[processor]], + samplers=None, + ) + + responses = MLLMBatchGenerator._next(generator) + + assert [r.token for r in responses] == [7] + assert captured["input_tokens"] == [[7]] + assert captured["output_tokens"] == [[5, 7]] + assert request.output_tokens == [5, 7] + + def test_install_mtp_mllm_disables_mtp_when_logits_processors_active(self): + from vllm_mlx.mllm_batch_generator import install_mtp_mllm + + expected_tokens = mx.array([7]) + expected_logprobs = [mx.array([0.1, 0.9])] + original_step = MagicMock(return_value=(expected_tokens, expected_logprobs)) + + class FakeBatchGen: + def __init__(self): + self._step = original_step + self._next = MagicMock(return_value=[]) + self.active_batch = MagicMock() + self.active_batch.__len__.return_value = 1 + self.active_batch.requests = [ + MagicMock( + temperature=0.0, + top_p=1.0, + top_k=0, + min_p=0.0, + ) + ] + self.sampler = MagicMock() + + batch_gen = FakeBatchGen() + language_model = MagicMock() + + install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4) + + logits_processor = MagicMock() + tokens, logprobs = batch_gen._step( + mx.array([[123]]), + cache=[], + logits_processors=[[logits_processor]], + output_tokens=[[1, 2]], + samplers=[None], + ) + + assert tokens.tolist() == expected_tokens.tolist() + assert [lp.tolist() for lp in logprobs] == [ + lp.tolist() for lp in expected_logprobs + ] + original_step.assert_called_once() + language_model.assert_not_called() + language_model.mtp_forward.assert_not_called() + + def test_install_mtp_mllm_disables_mtp_for_non_greedy_sampling(self): + from vllm_mlx.mllm_batch_generator import install_mtp_mllm + + expected_tokens = mx.array([11]) + expected_logprobs = [mx.array([0.3, 0.7])] + original_step = MagicMock(return_value=(expected_tokens, expected_logprobs)) + + class FakeBatchGen: + def __init__(self): + self._step = original_step + self._next = MagicMock(return_value=[]) + self.active_batch = MagicMock() + self.active_batch.__len__.return_value = 1 + self.active_batch.requests = [ + MagicMock( + temperature=0.6, + top_p=0.95, + top_k=20, + min_p=0.0, + ) + ] + self.sampler = MagicMock() + + batch_gen = FakeBatchGen() + language_model = MagicMock() + + install_mtp_mllm(batch_gen, language_model, num_draft_tokens=4) + + tokens, logprobs = batch_gen._step( + mx.array([[321]]), + cache=[], + logits_processors=None, + output_tokens=None, + samplers=[MagicMock()], + ) + + assert tokens.tolist() == expected_tokens.tolist() + assert [lp.tolist() for lp in logprobs] == [ + lp.tolist() for lp in expected_logprobs + ] + original_step.assert_called_once() + language_model.assert_not_called() + language_model.mtp_forward.assert_not_called() + + def test_install_mtp_mllm_accepted_drafts_bypass_request_sampler(self): + from vllm_mlx.mllm_batch_generator import MLLMBatchResponse, install_mtp_mllm + + class FakeBatchGen: + def __init__(self): + self._step = 
MagicMock() + self._next = MagicMock( + return_value=[ + MLLMBatchResponse( + uid=7, + request_id="req-7", + token=1, + logprobs=mx.array([0.0, 0.0, 0.0, 0.0, 0.0]), + finish_reason=None, + ) + ] + ) + self.active_batch = MagicMock() + self.active_batch.__len__.return_value = 1 + self.active_batch.uids = [7] + request = MagicMock( + request_id="req-7", + temperature=0.0, + top_p=1.0, + top_k=0, + min_p=0.0, + output_tokens=[], + ) + self.active_batch.requests = [request] + self.active_batch.num_tokens = [0] + self.active_batch.max_tokens = [16] + self.stop_tokens = set() + self.sampler = MagicMock(return_value=mx.array([1], dtype=mx.uint32)) + self._maybe_store_prefix_cache = MagicMock() + + batch_gen = FakeBatchGen() + request_sampler = MagicMock(return_value=mx.array([1], dtype=mx.uint32)) + + class FakeLanguageModel: + def mtp_forward(self, hidden_states, next_token_ids, mtp_cache=None): + logits = mx.full((1, 1, 5), -1000.0) + logits[:, :, 2] = 0.0 + return logits + + def __call__(self, verify_input, cache=None, return_hidden=False): + logits = mx.full((1, 2, 5), -1000.0) + logits[:, 0, 2] = 0.0 + logits[:, 1, 3] = 0.0 + return logits, mx.zeros((1, 2, 4)) + + install_mtp_mllm(batch_gen, FakeLanguageModel(), num_draft_tokens=1) + + batch_gen._step( + mx.array([[123]], dtype=mx.uint32), + cache=[], + logits_processors=None, + output_tokens=[[]], + samplers=[request_sampler], + ) + responses = batch_gen._next() + + assert [r.token for r in responses] == [1, 2] + assert request_sampler.call_count == 1 + assert batch_gen.sampler.call_count == 0 + + def test_next_keeps_retired_processors_by_default(self, monkeypatch): + from vllm_mlx.mllm_batch_generator import ( + MLLMBatch, + MLLMBatchGenerator, + MLLMBatchRequest, + MLLMBatchStats, + ) + + monkeypatch.delenv("VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME", raising=False) + + class RetiredProcessor: + is_retired = True + + def __call__(self, tokens, logits): + return logits + + processor = RetiredProcessor() + generator = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + generator._stats = MLLMBatchStats() + generator.stop_tokens = set() + generator.unprocessed_requests = [] + generator._pending_error_responses = [] + generator._prefill_progress = {} + generator.prefix_cache = None + generator._maybe_store_prefix_cache = lambda batch, end_idx: None + generator._step = lambda *args, **kwargs: ( + mx.array([11]), + [mx.array([0.2, 0.8])], + ) + + request = MLLMBatchRequest(uid=1, request_id="req-1", prompt="hello") + generator.active_batch = MLLMBatch( + uids=[1], + request_ids=["req-1"], + y=mx.array([7]), + logprobs=[mx.array([0.5, 0.5])], + max_tokens=[4], + num_tokens=[0], + cache=[], + requests=[request], + logits_processors=[[processor]], + samplers=None, + ) + + responses = MLLMBatchGenerator._next(generator) + + assert len(responses) == 1 + assert generator.active_batch is not None + assert generator.active_batch.logits_processors == [[processor]] + + def test_next_drops_retired_processors_only_when_enabled(self, monkeypatch): + from vllm_mlx.mllm_batch_generator import ( + MLLMBatch, + MLLMBatchGenerator, + MLLMBatchRequest, + MLLMBatchStats, + ) + + monkeypatch.setenv("VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME", "1") + + class RetiredProcessor: + is_retired = True + + def __call__(self, tokens, logits): + return logits + + generator = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + generator._stats = MLLMBatchStats() + generator.stop_tokens = set() + generator.unprocessed_requests = [] + generator._pending_error_responses = [] + 
generator._prefill_progress = {} + generator.prefix_cache = None + generator._maybe_store_prefix_cache = lambda batch, end_idx: None + generator._step = lambda *args, **kwargs: ( + mx.array([11]), + [mx.array([0.2, 0.8])], + ) + + request = MLLMBatchRequest(uid=1, request_id="req-1", prompt="hello") + generator.active_batch = MLLMBatch( + uids=[1], + request_ids=["req-1"], + y=mx.array([7]), + logprobs=[mx.array([0.5, 0.5])], + max_tokens=[4], + num_tokens=[0], + cache=[], + requests=[request], + logits_processors=[[RetiredProcessor()]], + samplers=None, + ) + + responses = MLLMBatchGenerator._next(generator) + + assert len(responses) == 1 + assert generator.active_batch is not None + assert generator.active_batch.logits_processors == [None] + + +class TestBatchedMLLMConfigWiring: + def test_batched_engine_forwards_prefill_step_size_to_mllm_scheduler( + self, monkeypatch + ): + from vllm_mlx.engine.batched import BatchedEngine + from vllm_mlx.scheduler import SchedulerConfig + + captured = {} + + class FakeMLXMultimodalLM: + def __init__(self, model_name, trust_remote_code=True): + self.model_name = model_name + self.model = object() + self.processor = object() + + def load(self): + return None + + class FakeMLLMSchedulerConfig: + def __init__(self, **kwargs): + captured["config_kwargs"] = kwargs + self.__dict__.update(kwargs) + + class FakeMLLMScheduler: + def __init__(self, model, processor, config): + captured["scheduler_config"] = config + + async def start(self): + return None + + import vllm_mlx.engine.batched as batched_mod + import vllm_mlx.mllm_scheduler as mllm_sched_mod + import vllm_mlx.models.mllm as mllm_model_mod + + monkeypatch.setattr(mllm_model_mod, "MLXMultimodalLM", FakeMLXMultimodalLM) + monkeypatch.setattr(mllm_sched_mod, "MLLMScheduler", FakeMLLMScheduler) + monkeypatch.setattr( + mllm_sched_mod, "MLLMSchedulerConfig", FakeMLLMSchedulerConfig + ) + monkeypatch.setattr( + batched_mod.BatchedEngine, "_inject_mtp_mllm", lambda self: None + ) + + cfg = SchedulerConfig( + prefill_batch_size=4, + completion_batch_size=8, + prefill_step_size=256, + enable_mtp=False, + ) + engine = BatchedEngine( + model_name="fake-qwen", + scheduler_config=cfg, + force_mllm=True, + ) + + asyncio.run(engine._start_mllm()) + + assert captured["config_kwargs"]["prefill_step_size"] == 256 + + +class TestPreprocessIdempotent: + """_preprocess_request must be idempotent for text-only requests. + + The scheduler offloads preprocessing to a thread-pool executor so + the event loop stays responsive. _process_prompts then calls + _preprocess_request again — the second call must be a no-op. 
+ """ + + def test_text_only_not_preprocessed_twice(self): + """When input_ids is already set (executor did it), skip.""" + from vllm_mlx.mllm_batch_generator import MLLMBatchRequest + + req = MLLMBatchRequest( + uid=0, + prompt="Hello", + request_id="test-idem", + ) + # Simulate executor having already set input_ids + req.input_ids = mx.array([[1, 2, 3]]) + + # Build a minimal batch generator with the method + from vllm_mlx.mllm_batch_generator import MLLMBatchGenerator + + gen = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + gen._preprocess_request = MLLMBatchGenerator._preprocess_request.__get__( + gen, MLLMBatchGenerator + ) + + # Must return immediately without touching prepare_inputs + gen._preprocess_request(req) + assert req.input_ids.shape == (1, 3) + + def test_vision_request_not_skipped(self): + """Vision requests should NOT be skipped even with input_ids set.""" + from vllm_mlx.mllm_batch_generator import MLLMBatchRequest + + req = MLLMBatchRequest( + uid=0, + prompt="Describe", + request_id="test-vis", + images=["fake.png"], + ) + req.input_ids = mx.array([[1, 2, 3]]) + + from vllm_mlx.mllm_batch_generator import MLLMBatchGenerator + + gen = MLLMBatchGenerator.__new__(MLLMBatchGenerator) + gen._preprocess_request = MLLMBatchGenerator._preprocess_request.__get__( + gen, MLLMBatchGenerator + ) + + # Should NOT return early — will try to import prepare_inputs + with pytest.raises(Exception): + gen._preprocess_request(req) diff --git a/tests/test_mllm_mtp_routing.py b/tests/test_mllm_mtp_routing.py index e2394cf69..69a38a1e4 100644 --- a/tests/test_mllm_mtp_routing.py +++ b/tests/test_mllm_mtp_routing.py @@ -3,13 +3,13 @@ def test_has_media_content_text_only(): - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content assert _has_media_content([{"role": "user", "content": "Hello"}]) is False def test_has_media_content_with_image(): - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content messages = [ { @@ -26,8 +26,22 @@ def test_has_media_content_with_image(): assert _has_media_content(messages) is True +def test_has_media_content_with_local_image_part(): + from vllm_mlx.api.utils import has_media_content as _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": "/tmp/frame.png"}, + ], + } + ] + assert _has_media_content(messages) is True + + def test_has_media_content_with_video(): - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content messages = [ { @@ -40,21 +54,35 @@ def test_has_media_content_with_video(): assert _has_media_content(messages) is True +def test_has_media_content_with_local_video_part(): + from vllm_mlx.api.utils import has_media_content as _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/tmp/v.mp4"}, + ], + } + ] + assert _has_media_content(messages) is True + + def test_has_media_content_empty(): - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content assert _has_media_content([]) is False def test_has_media_content_string_content(): """String content (not list) should return False.""" - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content assert _has_media_content([{"role": "user", "content": "Just 
text"}]) is False def test_has_media_content_audio(): - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content messages = [ { @@ -69,7 +97,7 @@ def test_has_media_content_audio(): def test_has_media_content_multi_turn(): """Media in earlier turns should still be detected.""" - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content messages = [ { @@ -90,7 +118,7 @@ def test_has_media_content_multi_turn(): def test_has_media_content_text_list(): """List content with only text parts should return False.""" - from vllm_mlx.engine.simple import _has_media_content + from vllm_mlx.api.utils import has_media_content as _has_media_content messages = [ { @@ -104,6 +132,37 @@ def test_has_media_content_text_list(): assert _has_media_content(messages) is False +def test_has_media_content_with_message_models(): + """Pydantic message models should follow the attribute-based content path.""" + from vllm_mlx.api.models import ContentPart, Message, VideoUrl + from vllm_mlx.api.utils import has_media_content as _has_media_content + + messages = [ + Message( + role="user", + content=[ + ContentPart(type="text", text="Describe this clip"), + ContentPart( + type="video_url", + video_url=VideoUrl(url="https://example.com/video.mp4"), + ), + ], + ) + ] + + assert _has_media_content(messages) is True + + +def test_has_media_content_none_content(): + """A schema-valid message with None content should not count as media.""" + from vllm_mlx.api.models import Message + from vllm_mlx.api.utils import has_media_content as _has_media_content + + messages = [Message(role="system", content=None)] + + assert _has_media_content(messages) is False + + # --- MLXMultimodalLM extraction method tests --- from unittest.mock import MagicMock diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py index b3625b3ec..a43c2c6db 100644 --- a/tests/test_native_tool_format.py +++ b/tests/test_native_tool_format.py @@ -40,6 +40,7 @@ def test_parsers_with_native_support(self): KimiToolParser, HermesToolParser, Glm47ToolParser, + QwenToolParser, ] for parser_cls in native_parsers: assert ( @@ -52,7 +53,6 @@ def test_parsers_with_native_support(self): def test_parsers_without_native_support(self): """Parsers that don't support native tool format should return False.""" non_native_parsers = [ - QwenToolParser, NemotronToolParser, xLAMToolParser, AutoToolParser, @@ -78,6 +78,7 @@ def test_via_manager(self): "kimi", "hermes", "glm47", + "qwen", ]: parser_cls = ToolParserManager.get_tool_parser(name) assert ( @@ -85,7 +86,7 @@ def test_via_manager(self): ), f"Parser '{name}' should support native format" # No native support - for name in ["qwen", "nemotron", "xlam", "auto"]: + for name in ["nemotron", "xlam", "auto"]: parser_cls = ToolParserManager.get_tool_parser(name) assert ( parser_cls.supports_native_format() is False diff --git a/tests/test_prompt_warmup.py b/tests/test_prompt_warmup.py new file mode 100644 index 000000000..c6585052c --- /dev/null +++ b/tests/test_prompt_warmup.py @@ -0,0 +1,670 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for vllm_mlx.prompt_warmup. 
+ +Covers: +- load_warmup_file: validation of shape, content, error paths +- warm_prefix_cache: happy path + error handling with a stub engine +- Agent/code-assistant scenarios that match real code-agent workloads +""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path +from typing import Any + +import pytest + +from vllm_mlx.prompt_warmup import ( + _build_strict_prefix_string, + _ensure_user_terminator, + load_warmup_file, + warm_prefix_cache, +) + +# --------------------------------------------------------------------------- +# Realistic agent system prompts (code-agent-style, ~1-2k tokens each) +# --------------------------------------------------------------------------- + +AGENT_SYSTEM_PROMPT_CODING_BASE = """You are a code assistant running inside a developer's terminal. + +Rules of engagement: +- Read files before editing. Never guess at code behaviour. +- Prefer small, focused edits over sweeping refactors. +- Match the project's existing conventions. +- Add tests that would catch the regression you are fixing. +- Do not add comments that restate what the code obviously does. +- Do not introduce new dependencies without flagging. + +Safety: +- Never generate secrets or credentials. +- Refuse destructive actions that affect shared state without confirmation. +- Sanitize input before interpolating into shell, SQL, HTML, regex. + +Tools: Read, Write, Edit, Bash, Grep, Glob, Agent. +Keep final responses concise. Use file_path:line_number for code references.""" + +AGENT_SYSTEM_PROMPT_CODING = AGENT_SYSTEM_PROMPT_CODING_BASE * 3 # ~1.5k tokens + + +AGENT_SYSTEM_PROMPT_REVIEWER_BASE = """You are a senior code reviewer. + +Evaluate diffs for: +- Correctness: off-by-one, race conditions, null handling, error paths. +- Security: input validation, authn/authz, injection, secrets leakage. +- Performance: N+1 queries, unbounded loops, memory leaks. +- Maintainability: naming, layering, testability, complexity. +- Conventions: does this match the rest of the codebase? + +For each finding, cite file_path:line_number and explain WHY it matters. +Prioritize by severity: blocker, major, minor, nit. +Do not nitpick formatting that a linter would catch.""" + +AGENT_SYSTEM_PROMPT_REVIEWER = AGENT_SYSTEM_PROMPT_REVIEWER_BASE * 4 # ~2k tokens + + +AGENT_CODING_MESSAGES_1 = [{"role": "system", "content": AGENT_SYSTEM_PROMPT_CODING}] + +AGENT_REVIEWER_MESSAGES = [{"role": "system", "content": AGENT_SYSTEM_PROMPT_REVIEWER}] + +AGENT_CONVERSATION_HISTORY = [ + {"role": "system", "content": AGENT_SYSTEM_PROMPT_CODING}, + { + "role": "user", + "content": ( + "I'm getting a TypeError in my LRU cache:\n\n" + "```python\n" + "class LRUCache:\n" + " def __init__(self, max_size):\n" + " self.max_size = max_size\n" + " self.data = {}\n" + "\n" + " def get(self, key):\n" + " with self._lock:\n" + " return self.data.get(key)\n" + "```\n\n" + "Error: `AttributeError: 'LRUCache' object has no attribute '_lock'`. " + "What's wrong?" + ), + }, + { + "role": "assistant", + "content": ( + "`_lock` is referenced in `get()` but never created in `__init__`.\n\n" + "Fix: add `self._lock = threading.Lock()` in `__init__` and `import threading` " + "at the top. Also apply the lock on any other mutating methods " + "(`put`, `evict`) so the cache is thread-safe.\n\n" + "Want me to patch the file?" 
+ ), + }, +] + + +# --------------------------------------------------------------------------- +# Stub engine (mimics BaseEngine.stream_chat interface) +# --------------------------------------------------------------------------- + + +class _FakeOutput: + def __init__(self, text: str, finished: bool, prompt_tokens: int): + self.text = text + self.new_text = text + self.prompt_tokens = prompt_tokens + self.completion_tokens = 1 + self.finished = finished + self.finish_reason = "stop" if finished else None + + +class _StubEngine: + """Mimics the minimal interface warm_prefix_cache expects.""" + + def __init__( + self, + *, + raise_on: int | None = None, + prompt_tokens_per_call: int = 100, + ) -> None: + self.calls: list[list[dict[str, Any]]] = [] + self.raise_on = raise_on + self.prompt_tokens_per_call = prompt_tokens_per_call + + async def stream_chat( + self, + *, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float, + **_: Any, + ): + self.calls.append(messages) + if self.raise_on is not None and len(self.calls) - 1 == self.raise_on: + raise RuntimeError("simulated engine failure") + yield _FakeOutput( + "ok", finished=True, prompt_tokens=self.prompt_tokens_per_call + ) + + +# --------------------------------------------------------------------------- +# load_warmup_file +# --------------------------------------------------------------------------- + + +def test_load_warmup_file_missing(tmp_path: Path): + missing = tmp_path / "nope.json" + with pytest.raises(FileNotFoundError): + load_warmup_file(str(missing)) + + +def test_load_warmup_file_not_a_list(tmp_path: Path): + p = tmp_path / "w.json" + p.write_text('{"role": "system", "content": "x"}') + with pytest.raises(ValueError, match="top-level JSON list"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_empty_list(tmp_path: Path): + p = tmp_path / "w.json" + p.write_text("[]") + with pytest.raises(ValueError, match="empty"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_entry_not_list(tmp_path: Path): + p = tmp_path / "w.json" + p.write_text('[{"role": "system", "content": "x"}]') + with pytest.raises(ValueError, match="non-empty list of message dicts"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_entry_empty_list(tmp_path: Path): + p = tmp_path / "w.json" + p.write_text("[[]]") + with pytest.raises(ValueError, match="non-empty list"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_message_missing_keys(tmp_path: Path): + p = tmp_path / "w.json" + p.write_text('[[{"role": "system"}]]') + with pytest.raises(ValueError, match="missing 'role' or 'content'"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_valid_single(tmp_path: Path): + p = tmp_path / "w.json" + data = [AGENT_CODING_MESSAGES_1] + p.write_text(json.dumps(data)) + loaded = load_warmup_file(str(p)) + assert loaded == data + + +def test_load_warmup_file_valid_multi_agent(tmp_path: Path): + """Real code-agent scenarios: coding agent, reviewer, and conversation history.""" + p = tmp_path / "w.json" + data = [ + AGENT_CODING_MESSAGES_1, + AGENT_REVIEWER_MESSAGES, + AGENT_CONVERSATION_HISTORY, + ] + p.write_text(json.dumps(data)) + loaded = load_warmup_file(str(p)) + assert len(loaded) == 3 + assert loaded[0][0]["role"] == "system" + assert loaded[2][-1]["role"] == "assistant" + + +def test_load_warmup_file_expands_user_home( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + """~/ in the path should resolve to $HOME.""" + monkeypatch.setenv("HOME", str(tmp_path)) + p = tmp_path / "w.json" + 
p.write_text(json.dumps([AGENT_CODING_MESSAGES_1])) + # Use a relative-to-$HOME path + loaded = load_warmup_file("~/w.json") + assert len(loaded) == 1 + + +# --------------------------------------------------------------------------- +# warm_prefix_cache — agent workload scenarios +# --------------------------------------------------------------------------- + + +def test_warm_prefix_cache_single_coding_agent(): + """Warm up a single code-agent-style system prompt.""" + engine = _StubEngine(prompt_tokens_per_call=400) + result = asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + assert result["count"] == 1 + assert result["skipped"] == 0 + assert result["total_prompt_tokens"] == 400 + # System-only prompt gets a user-terminator appended (required by some templates). + assert engine.calls[0][0] == AGENT_CODING_MESSAGES_1[0] + assert engine.calls[0][-1]["role"] == "user" + + +def test_warm_prefix_cache_multi_agent_personas(): + """Realistic multi-agent deployment: coding + reviewer + conversation history.""" + engine = _StubEngine(prompt_tokens_per_call=500) + prompts = [ + AGENT_CODING_MESSAGES_1, + AGENT_REVIEWER_MESSAGES, + AGENT_CONVERSATION_HISTORY, + ] + result = asyncio.run(warm_prefix_cache(engine, prompts)) + assert result["count"] == 3 + assert result["skipped"] == 0 + assert result["total_prompt_tokens"] == 1500 + # Two system-only prompts get user-terminators; the conversation history + # (which already ends with assistant) passes through untouched. + # Calls arrive in gather order — index by system-content match. + for prompt in prompts: + match = [c for c in engine.calls if c and c[0] == prompt[0]] + assert match, f"no call matched system content of {prompt[0]}" + # History prompt ends in assistant → untouched + hist_match = [c for c in engine.calls if len(c) == len(AGENT_CONVERSATION_HISTORY)] + assert hist_match and hist_match[0] == AGENT_CONVERSATION_HISTORY + + +def test_warm_prefix_cache_handles_individual_failure(): + """One prompt failing must not abort the rest of the warm-up.""" + engine = _StubEngine(raise_on=1, prompt_tokens_per_call=200) + prompts = [ + AGENT_CODING_MESSAGES_1, + AGENT_REVIEWER_MESSAGES, + AGENT_CONVERSATION_HISTORY, + ] + result = asyncio.run(warm_prefix_cache(engine, prompts)) + assert result["count"] == 2 + assert result["skipped"] == 1 + # total_prompt_tokens reflects only successful calls + assert result["total_prompt_tokens"] == 400 + + +def test_warm_prefix_cache_empty_list(): + """Empty prompt list returns zeros and does not touch the engine.""" + engine = _StubEngine() + result = asyncio.run(warm_prefix_cache(engine, [])) + assert result["count"] == 0 + assert result["skipped"] == 0 + assert result["total_prompt_tokens"] == 0 + assert engine.calls == [] + + +def test_warm_prefix_cache_reports_elapsed_time(): + """elapsed_ms should be >= 0 and not negative.""" + engine = _StubEngine(prompt_tokens_per_call=100) + result = asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + assert result["elapsed_ms"] >= 0 + + +def test_warm_prefix_cache_uses_max_tokens_1_by_default(): + """The warmer must not waste GPU time generating output — max_tokens defaults to 1.""" + captured_max_tokens: list[int] = [] + + class _Recording(_StubEngine): + async def stream_chat(self, *, messages, max_tokens, temperature, **_): + captured_max_tokens.append(max_tokens) + self.calls.append(messages) + yield _FakeOutput("x", finished=True, prompt_tokens=100) + + engine = _Recording() + asyncio.run(warm_prefix_cache(engine, 
[AGENT_CODING_MESSAGES_1])) + assert captured_max_tokens == [1] + + +def test_warm_prefix_cache_all_fail_does_not_raise(): + """If every prompt errors, the warmer reports skipped=N but does not crash.""" + engine = _StubEngine(raise_on=0, prompt_tokens_per_call=100) + # single prompt that errors + result = asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + assert result["count"] == 0 + assert result["skipped"] == 1 + assert result["total_prompt_tokens"] == 0 + + +# --------------------------------------------------------------------------- +# load → warm end-to-end (file → stub engine) +# --------------------------------------------------------------------------- + + +def test_end_to_end_file_then_warm(tmp_path: Path): + """Loading a warm-up file and running it through the warmer produces the + exact messages the engine would receive from a real HTTP request.""" + p = tmp_path / "agents.json" + prompts = [ + AGENT_CODING_MESSAGES_1, + AGENT_REVIEWER_MESSAGES, + ] + p.write_text(json.dumps(prompts)) + + loaded = load_warmup_file(str(p)) + engine = _StubEngine(prompt_tokens_per_call=750) + + result = asyncio.run(warm_prefix_cache(engine, loaded)) + assert result["count"] == 2 + assert result["total_prompt_tokens"] == 1500 + # The warmer appends a minimal user message for system-only prompts so + # templates that require a user (Qwen3.6, DeepSeek-VL) don't error. + # Verify the system content survives the transformation. + assert engine.calls[0][0] == prompts[0][0] + assert engine.calls[1][0] == prompts[1][0] + + +# --------------------------------------------------------------------------- +# Optimization 1 — user-terminator auto-append (Qwen3.6 / DeepSeek-VL fix) +# --------------------------------------------------------------------------- + + +def test_ensure_user_terminator_appends_for_system_only(): + msgs = [{"role": "system", "content": "SYS"}] + out = _ensure_user_terminator(msgs) + assert out == [ + {"role": "system", "content": "SYS"}, + {"role": "user", "content": " "}, + ] + + +def test_ensure_user_terminator_preserves_trailing_user(): + msgs = [ + {"role": "system", "content": "SYS"}, + {"role": "user", "content": "hi"}, + ] + out = _ensure_user_terminator(msgs) + assert out == msgs # untouched + + +def test_ensure_user_terminator_preserves_trailing_assistant(): + msgs = [ + {"role": "system", "content": "SYS"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + out = _ensure_user_terminator(msgs) + assert out == msgs # untouched — conversation history is fine + + +def test_warm_prefix_cache_appends_user_for_system_only(): + """Real behaviour: system-only prompt gets a user appended before going + to the engine — otherwise Qwen3.6-style templates raise TemplateError.""" + engine = _StubEngine(prompt_tokens_per_call=500) + asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + # Engine received [system, user-placeholder] + assert len(engine.calls[0]) == 2 + assert engine.calls[0][0]["role"] == "system" + assert engine.calls[0][1]["role"] == "user" + + +# --------------------------------------------------------------------------- +# Optimization 2 — parallel warm-up (asyncio.gather) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Edge cases — load_warmup_file +# --------------------------------------------------------------------------- + + +def test_load_warmup_file_invalid_json(tmp_path: Path): + p 
= tmp_path / "bad.json" + p.write_text("{not valid json") + with pytest.raises(json.JSONDecodeError): + load_warmup_file(str(p)) + + +def test_load_warmup_file_scalar_top_level(tmp_path: Path): + p = tmp_path / "str.json" + p.write_text('"just a string"') + with pytest.raises(ValueError, match="top-level JSON list"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_number_top_level(tmp_path: Path): + p = tmp_path / "num.json" + p.write_text("42") + with pytest.raises(ValueError, match="top-level JSON list"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_non_dict_message(tmp_path: Path): + p = tmp_path / "bad.json" + p.write_text('[[["role", "system"]]]') # list of [list of list] + with pytest.raises(ValueError, match="expected dict"): + load_warmup_file(str(p)) + + +def test_load_warmup_file_unicode_content(tmp_path: Path): + """Unicode in content survives round-trip.""" + p = tmp_path / "u.json" + content = "Héllo 世界 🌍 ñoño" + payload = [[{"role": "system", "content": content}]] + p.write_text(json.dumps(payload, ensure_ascii=False)) + loaded = load_warmup_file(str(p)) + assert loaded[0][0]["content"] == content + + +def test_load_warmup_file_many_prompts(tmp_path: Path): + """100 prompts load cleanly.""" + p = tmp_path / "many.json" + payload = [AGENT_CODING_MESSAGES_1 for _ in range(100)] + p.write_text(json.dumps(payload)) + loaded = load_warmup_file(str(p)) + assert len(loaded) == 100 + + +# --------------------------------------------------------------------------- +# Edge cases — _ensure_user_terminator +# --------------------------------------------------------------------------- + + +def test_ensure_user_terminator_empty_list(): + """Empty list → append user (harmless; won't be sent to engine in practice).""" + out = _ensure_user_terminator([]) + assert out == [{"role": "user", "content": " "}] + + +def test_ensure_user_terminator_tool_role(): + """Trailing tool message → we still append user (tool is not user/assistant).""" + msgs = [ + {"role": "system", "content": "SYS"}, + {"role": "user", "content": "call the tool"}, + {"role": "tool", "name": "calc", "content": "42"}, + ] + out = _ensure_user_terminator(msgs) + assert out[-1] == {"role": "user", "content": " "} + assert len(out) == len(msgs) + 1 + + +def test_ensure_user_terminator_only_user(): + """List with only user (no system) — preserved untouched.""" + msgs = [{"role": "user", "content": "just a user message"}] + out = _ensure_user_terminator(msgs) + assert out == msgs + + +# --------------------------------------------------------------------------- +# Edge cases — warm_prefix_cache +# --------------------------------------------------------------------------- + + +def test_warm_prefix_cache_engine_never_finishes(): + """Engine yields outputs but never sets finished=True → counted as skipped.""" + + class _NeverFinish(_StubEngine): + async def stream_chat(self, *, messages, max_tokens, temperature, **_): + self.calls.append(messages) + yield _FakeOutput("a", finished=False, prompt_tokens=100) + yield _FakeOutput("b", finished=False, prompt_tokens=100) + + engine = _NeverFinish() + result = asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + assert result["count"] == 0 + assert result["skipped"] == 1 + assert result["total_prompt_tokens"] == 0 + + +def test_warm_prefix_cache_all_prompts_fail_same_way(): + """All N prompts error identically → skipped=N, count=0, no crash.""" + engine = _StubEngine() + + # Force every call to raise by setting raise_on=0 and iterating manually + class 
_AlwaysFail(_StubEngine): + async def stream_chat(self, *, messages, max_tokens, temperature, **_): + self.calls.append(messages) + raise RuntimeError("boom") + yield # pragma: no cover — makes it an async generator + + engine = _AlwaysFail() + prompts = [AGENT_CODING_MESSAGES_1] * 3 + result = asyncio.run(warm_prefix_cache(engine, prompts)) + assert result["count"] == 0 + assert result["skipped"] == 3 + + +def test_warm_prefix_cache_unicode_in_prompts(): + """Unicode content passes through unmodified.""" + engine = _StubEngine(prompt_tokens_per_call=100) + prompts = [[{"role": "system", "content": "Héllo 世界 🌍"}]] + result = asyncio.run(warm_prefix_cache(engine, prompts)) + assert result["count"] == 1 + assert engine.calls[0][0]["content"] == "Héllo 世界 🌍" + + +def test_warm_prefix_cache_missing_prompt_tokens(): + """Engine returns None for prompt_tokens → counted as 0, no crash.""" + + class _NoTokenCount(_StubEngine): + async def stream_chat(self, *, messages, max_tokens, temperature, **_): + self.calls.append(messages) + yield _FakeOutput("x", finished=True, prompt_tokens=None) + + engine = _NoTokenCount() + result = asyncio.run(warm_prefix_cache(engine, [AGENT_CODING_MESSAGES_1])) + assert result["count"] == 1 + assert result["total_prompt_tokens"] == 0 + + +# --------------------------------------------------------------------------- +# _build_strict_prefix_string — the strict-prefix warmup builder +# --------------------------------------------------------------------------- + + +class _FakeTokenizer: + """Minimal tokenizer that mimics apply_chat_template returning a string.""" + + def __init__(self, render): + self._render = render + + def apply_chat_template(self, messages, **kwargs): + return self._render(messages, kwargs) + + +def test_build_strict_prefix_basic(): + """Qwen-style template: two user probes diverge at user content.""" + + def render(msgs, kwargs): + sys_content = msgs[0]["content"] + user_content = msgs[-1]["content"] + return ( + f"<|im_start|>system\n{sys_content}<|im_end|>\n" + f"<|im_start|>user\n{user_content}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + tok = _FakeTokenizer(render) + msgs = [{"role": "system", "content": "You are helpful."}] + prefix = _build_strict_prefix_string(tok, msgs) + assert prefix is not None + # Must end right before the user content insertion point. + assert prefix.endswith("<|im_start|>user\n") + # System content preserved. + assert "You are helpful." 
in prefix + + +def test_build_strict_prefix_none_if_tokenizer_lacks_method(): + """Tokenizer without apply_chat_template → None.""" + + class _Bare: + pass + + assert ( + _build_strict_prefix_string(_Bare(), [{"role": "system", "content": "x"}]) + is None + ) + + +def test_build_strict_prefix_none_on_template_error(): + """apply_chat_template raising returns None (and doesn't propagate).""" + + def render(msgs, kwargs): + raise RuntimeError("template broken") + + tok = _FakeTokenizer(render) + assert ( + _build_strict_prefix_string(tok, [{"role": "system", "content": "x"}]) is None + ) + + +def test_build_strict_prefix_none_if_identical_probes(): + """If the template ignores user content (both probes identical), there's + no divergence → can't produce a strict prefix → return None.""" + + def render(msgs, kwargs): + return "<|im_start|>system\nSYS<|im_end|><|im_start|>assistant\n" + + tok = _FakeTokenizer(render) + assert ( + _build_strict_prefix_string(tok, [{"role": "system", "content": "SYS"}]) is None + ) + + +def test_build_strict_prefix_retries_without_enable_thinking(): + """If the template rejects enable_thinking (non-Qwen models), retry without.""" + calls = [] + + def render(msgs, kwargs): + calls.append(kwargs.copy()) + if "enable_thinking" in kwargs: + raise TypeError("unexpected keyword argument 'enable_thinking'") + sys = msgs[0]["content"] + user = msgs[-1]["content"] + return f"<|begin_of_text|>system\n{sys}\n<|end|>user\n{user}\n<|end|>" + + tok = _FakeTokenizer(render) + prefix = _build_strict_prefix_string(tok, [{"role": "system", "content": "SYS"}]) + assert prefix is not None + # Called at least twice: first with enable_thinking (failed), then without + assert any("enable_thinking" in c for c in calls) + assert any("enable_thinking" not in c for c in calls) + + +def test_warm_prefix_cache_runs_prompts_concurrently(): + """With N prompts, elapsed_ms should be closer to time-of-one than N * time-of-one. + + We simulate engine latency by sleeping inside the stub. If the warmer + runs sequentially, total time is N * sleep. If concurrent, ~1 * sleep. + """ + import asyncio as _asyncio + + SLEEP_PER_CALL = 0.05 # 50 ms + + class _SlowEngine(_StubEngine): + async def stream_chat(self, *, messages, max_tokens, temperature, **_): + self.calls.append(messages) + await _asyncio.sleep(SLEEP_PER_CALL) + yield _FakeOutput("x", finished=True, prompt_tokens=100) + + N = 5 + engine = _SlowEngine(prompt_tokens_per_call=100) + prompts = [AGENT_CODING_MESSAGES_1] * N + + result = asyncio.run(warm_prefix_cache(engine, prompts)) + assert result["count"] == N + + # If serial, total >= N * SLEEP_PER_CALL * 1000 = 250 ms. + # If parallel via asyncio.gather, total is ~SLEEP_PER_CALL * 1000 = 50-100 ms. + # Allow generous slack; we just need to prove it's NOT serial. 
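+    # With N=5 and SLEEP_PER_CALL=0.05 s, the 0.6 factor puts the cutoff at
+    # 150 ms: a gather-based concurrent warm-up (~50-100 ms) stays well under
+    # it, while a sequential warm-up (>= 250 ms) cannot pass.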
+ assert result["elapsed_ms"] < (N * SLEEP_PER_CALL * 1000 * 0.6), ( + f"warm_prefix_cache appears serial: {result['elapsed_ms']:.1f} ms for " + f"{N} * {SLEEP_PER_CALL*1000:.0f} ms sleep" + ) diff --git a/tests/test_qwen35_mllm_patch.py b/tests/test_qwen35_mllm_patch.py new file mode 100644 index 000000000..ce230ba4f --- /dev/null +++ b/tests/test_qwen35_mllm_patch.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for the Qwen3.5 MLLM attention patch.""" + +import sys +import types +from typing import Any, cast + +import mlx.core as mx + +from vllm_mlx.patches.qwen3_5_mllm import patch_qwen35_attention_for_batching + + +def _position_ids(seq_len: int) -> mx.array: + base = mx.arange(seq_len).reshape(1, 1, seq_len) + return mx.tile(base, (3, 1, 1)) + + +def _install_fake_qwen35_modules(monkeypatch): + call_log: list[dict[str, int]] = [] + + class _IdentityNorm: + def __call__(self, x: mx.array) -> mx.array: + return x + + class DummyCache: + def __init__(self, offset=0): + self.offset = offset + + def update_and_fetch(self, keys: mx.array, values: mx.array): + return keys, values + + class DummyAttention: + def __init__(self): + self.num_attention_heads = 16 + self.num_key_value_heads = 16 + self.head_dim = 64 + self.scale = 1.0 + + self.q_proj = lambda x: mx.zeros( + (x.shape[0], x.shape[1], self.num_attention_heads * self.head_dim * 2), + dtype=x.dtype, + ) + self.k_proj = lambda x: mx.zeros( + (x.shape[0], x.shape[1], self.num_key_value_heads * self.head_dim), + dtype=x.dtype, + ) + self.v_proj = lambda x: mx.zeros( + (x.shape[0], x.shape[1], self.num_key_value_heads * self.head_dim), + dtype=x.dtype, + ) + self.q_norm = _IdentityNorm() + self.k_norm = _IdentityNorm() + self.o_proj = lambda x: x + + def _rotary_emb(values: mx.array, position_ids: mx.array): + if position_ids.ndim == 3: + batch = int(position_ids.shape[1]) + else: + batch = int(position_ids.shape[0]) + seq_len = int(position_ids.shape[-1]) + head_dim = int(values.shape[-1]) + cos = mx.ones((batch, seq_len, head_dim), dtype=values.dtype) + sin = mx.zeros((batch, seq_len, head_dim), dtype=values.dtype) + return cos, sin + + self.rotary_emb = _rotary_emb + + def fake_apply_multimodal_rotary_pos_emb( + queries: mx.array, + keys: mx.array, + cos: mx.array, + sin: mx.array, + ): + del sin + cos_expanded = mx.expand_dims(cos, axis=1) + q_len = int(queries.shape[-2]) + cos_len = int(cos_expanded.shape[-2]) + call_log.append({"q_len": q_len, "cos_len": cos_len}) + + if q_len != cos_len: + raise ValueError( + f"[broadcast_shapes] Shapes {tuple(queries.shape)} and " + f"{tuple(cos_expanded.shape)} cannot be broadcast." 
+ ) + + return queries, keys + + def fake_sdpa(queries: mx.array, keys: mx.array, values: mx.array, **kwargs): + del keys, values, kwargs + return queries + + fake_qwen35_mod = types.ModuleType("mlx_vlm.models.qwen3_5.language") + setattr(fake_qwen35_mod, "Qwen3_5Attention", DummyAttention) + setattr( + fake_qwen35_mod, + "apply_multimodal_rotary_pos_emb", + fake_apply_multimodal_rotary_pos_emb, + ) + + fake_lm_base_mod = types.ModuleType("mlx_lm.models.base") + setattr(fake_lm_base_mod, "scaled_dot_product_attention", fake_sdpa) + + monkeypatch.setitem(sys.modules, "mlx_vlm.models.qwen3_5.language", fake_qwen35_mod) + monkeypatch.setitem(sys.modules, "mlx_lm.models.base", fake_lm_base_mod) + + return DummyAttention, DummyCache, call_log + + +def test_qwen35_patch_generates_position_ids_when_missing(monkeypatch): + attention_cls, cache_cls, call_log = _install_fake_qwen35_modules(monkeypatch) + + assert patch_qwen35_attention_for_batching() is True + + attn = cast(Any, attention_cls()) + x = mx.zeros((1, 11, 1024), dtype=mx.float32) + cache = cache_cls(offset=0) + + out = attn(x, cache=cache, position_ids=None) + + assert out.shape[1] == 11 + assert call_log[-1] == {"q_len": 11, "cos_len": 11} + + +def test_qwen35_patch_accepts_matching_position_ids(monkeypatch): + attention_cls, cache_cls, call_log = _install_fake_qwen35_modules(monkeypatch) + + assert patch_qwen35_attention_for_batching() is True + + attn = cast(Any, attention_cls()) + x = mx.zeros((1, 28, 1024), dtype=mx.float32) + cache = cache_cls(offset=0) + + out = attn(x, cache=cache, position_ids=_position_ids(28)) + + assert out.shape[1] == 28 + assert call_log[-1] == {"q_len": 28, "cos_len": 28} + + +def test_qwen35_patch_recovers_from_stale_position_ids_between_requests(monkeypatch): + """Second call must not reuse shorter position_ids from a prior request.""" + attention_cls, cache_cls, call_log = _install_fake_qwen35_modules(monkeypatch) + + assert patch_qwen35_attention_for_batching() is True + + attn = cast(Any, attention_cls()) + cache = cache_cls(offset=0) + + short_position_ids = _position_ids(11) + + first_x = mx.zeros((1, 11, 1024), dtype=mx.float32) + attn(first_x, cache=cache, position_ids=short_position_ids) + + second_x = mx.zeros((1, 28, 1024), dtype=mx.float32) + out = attn(second_x, cache=cache, position_ids=short_position_ids) + + assert out.shape[1] == 28 + assert call_log[-1] == {"q_len": 28, "cos_len": 28} diff --git a/tests/test_qwen35_mtp_patch.py b/tests/test_qwen35_mtp_patch.py new file mode 100644 index 000000000..74dd4ff06 --- /dev/null +++ b/tests/test_qwen35_mtp_patch.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for Qwen3.5/3.6 MTP weight fixups.""" + +import mlx.core as mx + +from vllm_mlx.patches.qwen3_5_mtp import ( + _apply_qwen_mtp_rmsnorm_offset_fixups, +) + + +def test_qwen_mtp_raw_offset_norm_weights_shift_once(): + weights = { + "pre_fc_norm_hidden.weight": mx.array([-0.4, 0.0], dtype=mx.float32), + "layers.0.input_layernorm.weight": mx.array([-0.2, 0.1], dtype=mx.float32), + } + + shifted = _apply_qwen_mtp_rmsnorm_offset_fixups(weights) + + assert shifted == 2 + assert mx.allclose( + weights["pre_fc_norm_hidden.weight"], + mx.array([0.6, 1.0], dtype=mx.float32), + ) + assert mx.allclose( + weights["layers.0.input_layernorm.weight"], + mx.array([0.8, 1.1], dtype=mx.float32), + ) + + +def test_qwen_mtp_actual_gamma_norm_weights_are_not_shifted_again(): + original = mx.array([0.56, 0.82], dtype=mx.float32) + weights = {"pre_fc_norm_embedding.weight": 
original} + + shifted = _apply_qwen_mtp_rmsnorm_offset_fixups(weights) + + assert shifted == 0 + assert mx.allclose(weights["pre_fc_norm_embedding.weight"], original) + + +def test_qwen_mtp_non_norm_one_dimensional_weights_are_not_shifted(): + original = mx.array([-0.4, 0.0], dtype=mx.float32) + weights = {"layers.0.mlp.shared_expert_gate.weight": original} + + shifted = _apply_qwen_mtp_rmsnorm_offset_fixups(weights) + + assert shifted == 0 + assert mx.allclose(weights["layers.0.mlp.shared_expert_gate.weight"], original) diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py index a73d77a45..5eb124386 100644 --- a/tests/test_reasoning_parser.py +++ b/tests/test_reasoning_parser.py @@ -295,6 +295,13 @@ def test_nested_tags_not_supported(self, parser): reasoning, content = parser.extract_reasoning(output) # Result may vary by parser implementation + def test_repeated_leading_think_blocks_do_not_leak_to_content(self, parser): + """Repeated leading think blocks are reasoning, not final content.""" + output = "reasoning\n\n\n\nanswer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "reasoning" + assert content == "answer" + def test_streaming_reset_state(self, parser): """reset_state should allow reuse of parser.""" # First stream @@ -668,6 +675,15 @@ def test_qwen3_empty_think_tags(self, parser): assert reasoning is None or reasoning.strip() == "" assert content == "Just the answer." + def test_qwen3_empty_think_tags_after_implicit_transition(self, parser): + """Empty blocks after an implicit end tag must not leak to content.""" + output = ( + "brief reasoning\n\n\n\n\n\nACK_THINK_READY" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "brief reasoning" + assert content == "ACK_THINK_READY" + def test_qwen3_whitespace_between_tags(self, parser): """Test various whitespace patterns.""" test_cases = [ @@ -682,6 +698,36 @@ def test_qwen3_whitespace_between_tags(self, parser): assert reasoning is None or reasoning.strip() == "" assert expected_content in (content or "") + def test_qwen3_streaming_empty_think_tags_after_transition(self, parser): + """Streaming parser suppresses repeated leading think blocks in content.""" + parser.reset_state() + deltas = [ + "brief reasoning", + "", + "", + "\n\n\n\n", + "ACK_THINK_READY", + ] + previous = "" + current = "" + reasoning_parts = [] + content_parts = [] + + for delta in deltas: + previous = current + current += delta + result = parser.extract_reasoning_streaming(previous, current, delta) + if result is None: + continue + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + assert "".join(reasoning_parts) == "brief reasoning" + assert "".join(content_parts) == "ACK_THINK_READY" + class TestGptOssParser: """Tests for the GPT-OSS reasoning parser (channel-based format).""" @@ -1318,3 +1364,166 @@ def test_triple_end_tag_stripped(self, parser): reasoning, content = parser.extract_reasoning(output) assert reasoning == "reasoning" assert content == "content" + + +class TestReasoningStrippedFromToolCallContent: + """ + Verify that extract_reasoning strips reasoning markers from text that + already had tool call markup removed (simulating server.py tool_calls path). + + When tool calls are present, server.py first runs the tool parser (which + strips tool markup but leaves reasoning markers), then runs + extract_reasoning on the cleaned_text to strip reasoning markers. 
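+
+    The content field is therefore expected to come back empty in most of
+    these cases: the user-facing answer travels in the parsed tool_calls,
+    not in the residual text.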
+ """ + + def test_gemma4_channel_tokens_stripped_after_tool_parse(self): + """Gemma 4 channel tokens should be stripped from post-tool-parse text.""" + parser = get_parser("gemma4")() + # Simulate what tool parser leaves behind: reasoning markers + residual text + # (tool markup like ```json...``` already removed by tool parser) + post_tool_text = "<|channel>thought\nLet me find the weather.\n" + reasoning, content = parser.extract_reasoning(post_tool_text) + assert reasoning == "Let me find the weather." + # content should be None or empty (tool calls carry the actual response) + assert content is None or content == "" + + def test_gemma4_alternative_format_stripped_after_tool_parse(self): + """Alternative <|channel>response format should also be stripped.""" + parser = get_parser("gemma4")() + post_tool_text = "<|channel>thought\nChecking parameters.\n<|channel>response\n" + reasoning, content = parser.extract_reasoning(post_tool_text) + assert reasoning == "Checking parameters." + assert content is None or content.strip() == "" + + def test_gemma4_empty_input_after_tool_parse(self): + """Empty string (all content was tool markup) should not crash.""" + parser = get_parser("gemma4")() + reasoning, content = parser.extract_reasoning("") + assert reasoning is None + + def test_qwen3_think_tags_stripped_after_tool_parse(self): + """Qwen3 tags should be stripped from post-tool-parse text.""" + parser = get_parser("qwen3")() + post_tool_text = "Let me call the function." + reasoning, content = parser.extract_reasoning(post_tool_text) + assert reasoning == "Let me call the function." + assert content is None or content == "" + + def test_deepseek_think_tags_stripped_after_tool_parse(self): + """DeepSeek-R1 tags should be stripped from post-tool-parse text.""" + parser = get_parser("deepseek_r1")() + post_tool_text = "I need to call this API." + reasoning, content = parser.extract_reasoning(post_tool_text) + assert reasoning == "I need to call this API." + assert content is None or content == "" + + def test_gemma4_residual_text_preserved(self): + """Non-reasoning text after channel markers should be preserved as content.""" + parser = get_parser("gemma4")() + post_tool_text = "<|channel>thought\nThinking...\nSome residual text" + reasoning, content = parser.extract_reasoning(post_tool_text) + assert reasoning == "Thinking..." + assert content == "Some residual text" + + +class TestGemma4DegenerateCycling: + """ + Test handling of degenerate thought/response cycling in Gemma 4. + + On long prompts with tools, Gemma 4 may oscillate between thought and + response channels many times before producing valid output. The parser + must split at the LAST so all cycles go into reasoning and + only the final response goes into content. 
+ """ + + def test_multiple_channel_end_tokens_in_content(self): + """Multiple tokens: rpartition splits at the LAST one.""" + parser = get_parser("gemma4")() + output = ( + "<|channel>thought\nGarbage loop\nthought\nThe answer" + ) + reasoning, content = parser.extract_reasoning(output) + assert content == "The answer" + assert "Garbage loop" in reasoning + + def test_residual_thought_channel_stripped_from_content(self): + """thought\\n residuals must not leak into content.""" + parser = get_parser("gemma4")() + output = ( + "<|channel>thought\nLoop1\n" + "<|channel>thought\nLoop2\n" + "Final response here" + ) + reasoning, content = parser.extract_reasoning(output) + assert content == "Final response here" + assert "" not in (content or "") + assert "<|channel>" not in (content or "") + + def test_many_cycles_all_go_to_reasoning(self): + """All intermediate thought/response cycles end up in reasoning.""" + parser = get_parser("gemma4")() + output = ( + "<|channel>thought\nCycle1\n" + "<|channel>thought\nCycle2\n" + "<|channel>thought\nCycle3\n" + "Final content" + ) + reasoning, content = parser.extract_reasoning(output) + assert content == "Final content" + assert "Cycle1" in reasoning + assert "Cycle2" in reasoning + assert "Cycle3" in reasoning + + def test_channel_tokens_stripped_from_reasoning(self): + """Channel special tokens should not appear in reasoning output.""" + parser = get_parser("gemma4")() + output = ( + "<|channel>thought\nThink1\n" + "<|channel>thought\nThink2\n" + "Result" + ) + reasoning, content = parser.extract_reasoning(output) + assert "" not in (reasoning or "") + assert "<|channel>" not in (reasoning or "") + + def test_streaming_degenerate_cycling(self): + """Streaming: re-entry into thinking after content transition.""" + parser = get_parser("gemma4")() + deltas = [ + "<|channel>", + "thought\n", + "Thinking...", + "\n", + "\nSome content", + "\n<|channel>", + "thought\n", + "More thinking", + "\n", + "\nFinal answer", + ] + reasoning_parts = [] + content_parts = [] + accumulated = "" + for delta in deltas: + prev = accumulated + accumulated += delta + result = parser.extract_reasoning_streaming(prev, accumulated, delta) + if result is not None: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + final = parser.finalize_stream() + if final: + if final.reasoning: + reasoning_parts.append(final.reasoning) + if final.content: + content_parts.append(final.content) + + full_content = "".join(content_parts) + # Final answer should be in content (after last ) + assert "Final answer" in full_content + # Channel tokens must not leak + assert "" not in full_content + assert "<|channel>" not in full_content diff --git a/tests/test_rerank.py b/tests/test_rerank.py new file mode 100644 index 000000000..b110ef5a1 --- /dev/null +++ b/tests/test_rerank.py @@ -0,0 +1,872 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the /v1/rerank endpoint.""" + +import platform +import sys +from unittest.mock import MagicMock + +import pytest + +# Skip all tests if not on Apple Silicon +pytestmark = pytest.mark.skipif( + sys.platform != "darwin" or platform.machine() != "arm64", + reason="Requires Apple Silicon", +) + + +# ============================================================================= +# Unit Tests - Pydantic Models +# ============================================================================= + + +class TestRerankModels: + """Test rerank request/response Pydantic models.""" + + def 
test_rerank_request_with_string_documents(self): + """Test RerankRequest accepts a list of plain strings.""" + from vllm_mlx.api.models import RerankRequest + + req = RerankRequest( + model="jina-reranker-v2", + query="What is deep learning?", + documents=["doc one", "doc two"], + ) + assert req.model == "jina-reranker-v2" + assert req.query == "What is deep learning?" + assert req.documents == ["doc one", "doc two"] + assert req.top_n is None + assert req.return_documents is True + + def test_rerank_request_with_object_documents(self): + """Test RerankRequest accepts a list of {text: ...} objects.""" + from vllm_mlx.api.models import RerankRequest + + req = RerankRequest( + model="jina-reranker-v2", + query="query", + documents=[{"text": "alpha"}, {"text": "beta"}], + ) + assert req.documents == [{"text": "alpha"}, {"text": "beta"}] + + def test_rerank_request_top_n_default_none(self): + """Test that top_n defaults to None (return all).""" + from vllm_mlx.api.models import RerankRequest + + req = RerankRequest(model="m", query="q", documents=["a"]) + assert req.top_n is None + + def test_rerank_result_serialization(self): + """Test RerankResult serializes with correct fields.""" + from vllm_mlx.api.models import RerankResult + + result = RerankResult( + index=2, + relevance_score=0.95, + document={"text": "hello"}, + ) + d = result.model_dump() + assert d["index"] == 2 + assert d["relevance_score"] == 0.95 + assert d["document"] == {"text": "hello"} + + def test_rerank_result_without_document(self): + """Test RerankResult when document is omitted.""" + from vllm_mlx.api.models import RerankResult + + result = RerankResult(index=0, relevance_score=0.5) + d = result.model_dump() + assert d["index"] == 0 + assert d["document"] is None + + def test_rerank_response_serialization(self): + """Test RerankResponse serializes to Jina-compatible JSON.""" + from vllm_mlx.api.models import RerankResponse, RerankResult, RerankUsage + + response = RerankResponse( + model="jina-reranker-v2", + results=[ + RerankResult(index=1, relevance_score=0.9, document={"text": "best"}), + RerankResult(index=0, relevance_score=0.1, document={"text": "worst"}), + ], + usage=RerankUsage(total_tokens=42), + ) + d = response.model_dump() + assert d["model"] == "jina-reranker-v2" + assert len(d["results"]) == 2 + assert d["results"][0]["index"] == 1 + assert d["results"][0]["relevance_score"] == 0.9 + assert d["usage"]["total_tokens"] == 42 + + +# ============================================================================= +# Unit Tests - Reranker Adapter Contract +# ============================================================================= + + +class TestRerankAdapterContract: + """Test the base adapter interface and default sigmoid adapter.""" + + def test_base_adapter_is_abstract(self): + """Test that RerankAdapter cannot be instantiated directly.""" + from vllm_mlx.rerank import RerankAdapter + + with pytest.raises(TypeError): + RerankAdapter() + + def test_sigmoid_adapter_normalize_maps_to_zero_one(self): + """Test that SigmoidAdapter.normalize applies sigmoid correctly.""" + from vllm_mlx.rerank import SigmoidAdapter + + adapter = SigmoidAdapter() + # sigmoid(0) = 0.5 + assert abs(adapter.normalize(0.0) - 0.5) < 1e-6 + # sigmoid(large positive) -> ~1.0 + assert adapter.normalize(10.0) > 0.999 + # sigmoid(large negative) -> ~0.0 + assert adapter.normalize(-10.0) < 0.001 + + def test_sigmoid_adapter_extract_score_takes_first_logit(self): + """Test that SigmoidAdapter.extract_score returns logits[0].""" + from 
vllm_mlx.rerank import SigmoidAdapter + + adapter = SigmoidAdapter() + # Simulate a logits array with shape (num_labels,) + logits = [2.5, -1.0, 0.3] + assert adapter.extract_score(logits) == 2.5 + + def test_sigmoid_adapter_tokenize_pair_returns_dict(self): + """Test that SigmoidAdapter.tokenize_pair produces a token dict.""" + from unittest.mock import MagicMock + + from vllm_mlx.rerank import SigmoidAdapter + + adapter = SigmoidAdapter() + mock_tokenizer = MagicMock() + mock_tokenizer.return_value = { + "input_ids": [[101, 2054, 102, 3793, 102]], + "attention_mask": [[1, 1, 1, 1, 1]], + } + result = adapter.tokenize_pair(mock_tokenizer, "query", "document") + mock_tokenizer.assert_called_once_with( + "query", + "document", + padding=True, + truncation=True, + max_length=512, + return_tensors="np", + ) + assert "input_ids" in result + assert "attention_mask" in result + + +# ============================================================================= +# Unit Tests - RerankEngine +# ============================================================================= + + +class TestRerankEngine: + """Test the RerankEngine model loading and scoring.""" + + def test_engine_not_loaded_initially(self): + """Test that a new engine reports is_loaded=False.""" + from vllm_mlx.rerank import RerankEngine + + engine = RerankEngine("test-model") + assert engine.is_loaded is False + + def test_engine_model_name_stored(self): + """Test that model_name is stored on construction.""" + from vllm_mlx.rerank import RerankEngine + + engine = RerankEngine("mlx-community/jina-reranker-v2-base-multilingual") + assert engine.model_name == "mlx-community/jina-reranker-v2-base-multilingual" + + def test_score_pairs_returns_normalized_scores(self): + """Test score_pairs returns sigmoid-normalized scores for each pair.""" + import math + + import numpy as np + + from vllm_mlx.rerank import RerankEngine, SigmoidAdapter + + engine = RerankEngine("test-model") + + # Mock model: returns logits with shape (batch, num_labels) + mock_model = MagicMock() + # Simulate two pairs scored: logits [2.0, ...] and [-1.0, ...] + mock_logits = MagicMock() + mock_logits.tolist.return_value = [[2.0, 0.5], [-1.0, 0.3]] + mock_model.return_value = MagicMock(logits=mock_logits) + + mock_tokenizer = MagicMock() + mock_tokenizer.return_value = { + "input_ids": np.array([[1, 2, 3], [4, 5, 6]]), + "attention_mask": np.array([[1, 1, 1], [1, 1, 1]]), + } + + engine._model = mock_model + engine._tokenizer = mock_tokenizer + engine._adapter = SigmoidAdapter() + + scores, total_tokens = engine.score_pairs("test query", ["doc one", "doc two"]) + assert len(scores) == 2 + assert total_tokens == 6 # 3 tokens per pair + expected_0 = 1.0 / (1.0 + math.exp(-2.0)) + expected_1 = 1.0 / (1.0 + math.exp(1.0)) + assert abs(scores[0] - expected_0) < 1e-6 + assert abs(scores[1] - expected_1) < 1e-6 + + def test_score_pairs_token_budget_batching(self): + """Test that score_pairs splits work into batches by token budget.""" + import numpy as np + + from vllm_mlx.rerank import RerankEngine, SigmoidAdapter + + engine = RerankEngine("test-model", token_budget=10) + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + # Each call returns 1 pair. We have 3 documents. + # With token_budget=10 and each pair using 5 tokens, we get batches of 2 then 1. 
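+        # (Expected split: d1+d2 = 10 tokens fills the budget, d3 goes alone.)
+        # mock_forward below sizes its logits to input_ids.shape[0], so the
+        # scores stay aligned however the engine splits the batches.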
+ call_count = 0 + + def mock_tokenize(query, doc, **kwargs): + nonlocal call_count + call_count += 1 + return { + "input_ids": np.array([[1, 2, 3, 4, 5]]), + "attention_mask": np.array([[1, 1, 1, 1, 1]]), + } + + mock_tokenizer.side_effect = mock_tokenize + + def mock_forward(input_ids, **kwargs): + # Return logits matching the batch dimension of input_ids + batch_size = input_ids.shape[0] + mock_logits = MagicMock() + mock_logits.tolist.return_value = [[0.5]] * batch_size + return MagicMock(logits=mock_logits) + + mock_model.side_effect = mock_forward + + engine._model = mock_model + engine._tokenizer = mock_tokenizer + engine._adapter = SigmoidAdapter() + + scores, total_tokens = engine.score_pairs("q", ["d1", "d2", "d3"]) + assert len(scores) == 3 + assert total_tokens == 15 # 3 docs * 5 tokens each + # tokenizer called once per document (for budget estimation + scoring) + assert call_count == 3 + + def test_score_pairs_returns_token_count(self): + """Test that score_pairs returns total_tokens consistent with scoring.""" + import numpy as np + + from vllm_mlx.rerank import RerankEngine, SigmoidAdapter + + engine = RerankEngine("test-model") + + mock_model = MagicMock() + mock_logits = MagicMock() + mock_logits.tolist.return_value = [[0.5], [0.3]] + mock_model.return_value = MagicMock(logits=mock_logits) + + mock_tokenizer = MagicMock() + # Each pair tokenizes to 7 tokens + mock_tokenizer.return_value = { + "input_ids": np.array([[1, 2, 3, 4, 5, 6, 7]]), + "attention_mask": np.array([[1, 1, 1, 1, 1, 1, 1]]), + } + + engine._model = mock_model + engine._tokenizer = mock_tokenizer + engine._adapter = SigmoidAdapter() + + scores, total_tokens = engine.score_pairs("query", ["doc1", "doc2"]) + assert len(scores) == 2 + # 2 docs * 7 tokens each = 14 + assert total_tokens == 14 + + +# ============================================================================= +# Unit Tests - Classifier Forward Pass +# ============================================================================= + + +class TestClassifierForward: + """Test the MLX classifier forward pass for cross-encoder models.""" + + def test_classifier_forward_returns_logits_shape(self): + """Test that classifier_forward returns logits with (batch, num_labels) shape.""" + import mlx.core as mx + + from vllm_mlx.rerank_forward import classifier_forward + + # Build minimal BERT-like weights + vocab_size = 30 + hidden_size = 16 + num_heads = 2 + intermediate_size = 32 + num_labels = 1 + + config = { + "hidden_size": hidden_size, + "num_attention_heads": num_heads, + "intermediate_size": intermediate_size, + "num_hidden_layers": 1, + "num_labels": num_labels, + "vocab_size": vocab_size, + "max_position_embeddings": 64, + "type_vocab_size": 2, + "layer_norm_eps": 1e-12, + "hidden_act": "gelu", + } + + weights = _make_bert_weights(config) + + input_ids = mx.array([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]]) + attention_mask = mx.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]) + + logits = classifier_forward(input_ids, attention_mask, weights, config) + mx.eval(logits) + + assert logits.shape == (2, num_labels) + + def test_classifier_forward_different_num_labels(self): + """Test forward pass with num_labels=3 (multi-class).""" + import mlx.core as mx + + from vllm_mlx.rerank_forward import classifier_forward + + hidden_size = 8 + config = { + "hidden_size": hidden_size, + "num_attention_heads": 2, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_labels": 3, + "vocab_size": 20, + "max_position_embeddings": 32, + "type_vocab_size": 2, + 
"layer_norm_eps": 1e-12, + "hidden_act": "gelu", + } + + weights = _make_bert_weights(config) + + input_ids = mx.array([[1, 2, 3]]) + attention_mask = mx.array([[1, 1, 1]]) + + logits = classifier_forward(input_ids, attention_mask, weights, config) + mx.eval(logits) + + assert logits.shape == (1, 3) + + +def _make_bert_weights(config: dict) -> dict: + """Build minimal random BERT-style weights for testing.""" + import mlx.core as mx + + h = config["hidden_size"] + inter = config["intermediate_size"] + vocab = config["vocab_size"] + max_pos = config["max_position_embeddings"] + type_vocab = config["type_vocab_size"] + num_labels = config["num_labels"] + n_layers = config["num_hidden_layers"] + + w = {} + # Embeddings + w["bert.embeddings.word_embeddings.weight"] = mx.random.normal((vocab, h)) * 0.02 + w["bert.embeddings.position_embeddings.weight"] = ( + mx.random.normal((max_pos, h)) * 0.02 + ) + w["bert.embeddings.token_type_embeddings.weight"] = ( + mx.random.normal((type_vocab, h)) * 0.02 + ) + w["bert.embeddings.LayerNorm.weight"] = mx.ones((h,)) + w["bert.embeddings.LayerNorm.bias"] = mx.zeros((h,)) + + for i in range(n_layers): + prefix = f"bert.encoder.layer.{i}" + # Self-attention + for proj in ["query", "key", "value"]: + w[f"{prefix}.attention.self.{proj}.weight"] = ( + mx.random.normal((h, h)) * 0.02 + ) + w[f"{prefix}.attention.self.{proj}.bias"] = mx.zeros((h,)) + w[f"{prefix}.attention.output.dense.weight"] = mx.random.normal((h, h)) * 0.02 + w[f"{prefix}.attention.output.dense.bias"] = mx.zeros((h,)) + w[f"{prefix}.attention.output.LayerNorm.weight"] = mx.ones((h,)) + w[f"{prefix}.attention.output.LayerNorm.bias"] = mx.zeros((h,)) + # FFN + w[f"{prefix}.intermediate.dense.weight"] = mx.random.normal((inter, h)) * 0.02 + w[f"{prefix}.intermediate.dense.bias"] = mx.zeros((inter,)) + w[f"{prefix}.output.dense.weight"] = mx.random.normal((h, inter)) * 0.02 + w[f"{prefix}.output.dense.bias"] = mx.zeros((h,)) + w[f"{prefix}.output.LayerNorm.weight"] = mx.ones((h,)) + w[f"{prefix}.output.LayerNorm.bias"] = mx.zeros((h,)) + + # Pooler + w["bert.pooler.dense.weight"] = mx.random.normal((h, h)) * 0.02 + w["bert.pooler.dense.bias"] = mx.zeros((h,)) + + # Classifier + w["classifier.weight"] = mx.random.normal((num_labels, h)) * 0.02 + w["classifier.bias"] = mx.zeros((num_labels,)) + + return w + + +# ============================================================================= +# Integration Tests - FastAPI Endpoint +# ============================================================================= + + +class TestRerankEndpoint: + """Test the /v1/rerank endpoint via TestClient.""" + + @pytest.fixture() + def client(self): + """Create a FastAPI test client.""" + from fastapi.testclient import TestClient + + from vllm_mlx.server import app + + return TestClient(app) + + def test_rerank_returns_sorted_results(self, client): + """Test that /v1/rerank returns results sorted by relevance_score descending.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + # doc 0 gets low score, doc 1 gets high score, doc 2 gets mid score + mock_engine.score_pairs.return_value = ([0.1, 0.9, 0.5], 30) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "What is deep learning?", + "documents": ["bad match", "great match", "ok match"], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + 
assert len(body["results"]) == 3 + # Sorted descending by score + assert body["results"][0]["index"] == 1 + assert body["results"][0]["relevance_score"] == 0.9 + assert body["results"][1]["index"] == 2 + assert body["results"][1]["relevance_score"] == 0.5 + assert body["results"][2]["index"] == 0 + assert body["results"][2]["relevance_score"] == 0.1 + + def test_rerank_top_n(self, client): + """Test that top_n limits returned results.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.1, 0.9, 0.5], 20) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "query", + "documents": ["a", "b", "c"], + "top_n": 2, + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + assert len(body["results"]) == 2 + assert body["results"][0]["relevance_score"] == 0.9 + assert body["results"][1]["relevance_score"] == 0.5 + + def test_rerank_top_n_exceeds_documents_returns_400(self, client): + """Test that top_n > len(documents) returns 400.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "query", + "documents": ["a", "b"], + "top_n": 5, + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 400 + assert "top_n" in resp.json()["detail"] + + def test_rerank_return_documents_false(self, client): + """Test that return_documents=false omits document text.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.7], 5) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "q", + "documents": ["d"], + "return_documents": False, + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + assert body["results"][0]["document"] is None + + def test_rerank_preserves_object_documents(self, client): + """Test that object documents preserve original structure.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.8, 0.3], 10) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "q", + "documents": [ + {"text": "alpha", "metadata": "extra1"}, + {"text": "beta", "metadata": "extra2"}, + ], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + # Results sorted by score descending; index 0 (0.8) first + assert body["results"][0]["index"] == 0 + assert body["results"][0]["document"]["text"] == "alpha" + assert body["results"][0]["document"]["metadata"] == "extra1" + + def test_rerank_empty_documents_returns_400(self, client): + """Test that empty document list returns 400.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": 
"test-reranker", + "query": "q", + "documents": [], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 400 + + def test_rerank_usage_tokens(self, client): + """Test that usage.total_tokens is included in response.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.5], 42) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "q", + "documents": ["d"], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + assert resp.json()["usage"]["total_tokens"] == 42 + + def test_rerank_model_locked_rejects_different_model(self, client): + """Test that a locked reranker model rejects requests for different models.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "locked-reranker" + + original_engine = srv._rerank_engine + original_locked = srv._rerank_model_locked + srv._rerank_engine = mock_engine + srv._rerank_model_locked = "locked-reranker" + + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "other-reranker", + "query": "q", + "documents": ["d"], + }, + ) + assert resp.status_code == 400 + body = resp.json() + assert "locked-reranker" in body["detail"] + assert "other-reranker" in body["detail"] + finally: + srv._rerank_engine = original_engine + srv._rerank_model_locked = original_locked + + def test_rerank_no_engine_returns_503(self, client): + """Test that requesting rerank without a loaded engine returns 404.""" + import vllm_mlx.server as srv + + original = srv._rerank_engine + srv._rerank_engine = None + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "any-model", + "query": "q", + "documents": ["d"], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 404 + assert "--rerank-model" in resp.json()["detail"] + + def test_rerank_empty_query_returns_400(self, client): + """Test that an empty query returns 400.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "", + "documents": ["d"], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 400 + assert "Query" in resp.json()["detail"] + + +# ============================================================================= +# Unit Tests - CLI Argument +# ============================================================================= + + +class TestRerankCLI: + """Test that --rerank-model CLI arg is registered.""" + + def test_rerank_model_arg_exists(self): + """Test that the serve subcommand accepts --rerank-model.""" + from vllm_mlx.cli import build_parser + + parser = build_parser() + # Parse with --rerank-model; should not raise + args = parser.parse_args( + [ + "serve", + "test-model", + "--rerank-model", + "mlx-community/jina-reranker-v2-base-multilingual", + ] + ) + assert args.rerank_model == "mlx-community/jina-reranker-v2-base-multilingual" + + def test_rerank_model_arg_default_none(self): + """Test that --rerank-model defaults to None.""" + from vllm_mlx.cli import build_parser + + parser = build_parser() + args = parser.parse_args(["serve", "test-model"]) + assert args.rerank_model is None + + +# 
============================================================================= +# Integration Tests - Full Round-Trip +# ============================================================================= + + +class TestRerankIntegration: + """Full round-trip tests exercising the endpoint with mocked RerankEngine.""" + + @pytest.fixture() + def client(self): + from fastapi.testclient import TestClient + + from vllm_mlx.server import app + + return TestClient(app) + + def test_full_roundtrip_string_documents(self, client): + """Test full request/response cycle with string documents.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.2, 0.8, 0.5], 45) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "machine learning", + "documents": [ + "Cooking recipes for beginners", + "Introduction to neural networks", + "History of computer science", + ], + "top_n": 2, + "return_documents": True, + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + + # Verify structure + assert "model" in body + assert body["model"] == "test-reranker" + assert "results" in body + assert "usage" in body + assert body["usage"]["total_tokens"] == 45 + + # Verify top_n + assert len(body["results"]) == 2 + + # Verify sorting (descending by score) + assert ( + body["results"][0]["relevance_score"] + >= body["results"][1]["relevance_score"] + ) + + # Verify index preservation + assert body["results"][0]["index"] == 1 # "Introduction to neural networks" + assert body["results"][1]["index"] == 2 # "History of computer science" + + # Verify documents included + assert ( + body["results"][0]["document"]["text"] == "Introduction to neural networks" + ) + assert body["results"][1]["document"]["text"] == "History of computer science" + + def test_full_roundtrip_object_documents(self, client): + """Test full request/response cycle with object documents preserving metadata.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + mock_engine.score_pairs.return_value = ([0.9, 0.1], 20) + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.post( + "/v1/rerank", + json={ + "model": "test-reranker", + "query": "q", + "documents": [ + {"text": "relevant doc", "id": "doc-001", "source": "arxiv"}, + {"text": "irrelevant doc", "id": "doc-002", "source": "wiki"}, + ], + }, + ) + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + assert len(body["results"]) == 2 + + # First result should be index 0 (score 0.9) + top = body["results"][0] + assert top["index"] == 0 + assert top["document"]["text"] == "relevant doc" + assert top["document"]["id"] == "doc-001" + assert top["document"]["source"] == "arxiv" + + def test_rerank_model_appears_in_v1_models(self, client): + """Test that a loaded reranker model appears in /v1/models.""" + import vllm_mlx.server as srv + + mock_engine = MagicMock() + mock_engine.model_name = "test-reranker" + + original = srv._rerank_engine + srv._rerank_engine = mock_engine + try: + resp = client.get("/v1/models") + finally: + srv._rerank_engine = original + + assert resp.status_code == 200 + body = resp.json() + reranker_models = [ + m for m in body["data"] if m.get("owned_by") == "vllm-mlx-reranker" + ] + assert 
len(reranker_models) == 1 + assert reranker_models[0]["id"] == "test-reranker" diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py new file mode 100644 index 000000000..78f8c4c20 --- /dev/null +++ b/tests/test_responses_api.py @@ -0,0 +1,767 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the OpenAI-compatible Responses API.""" + +import json +import platform +import sys +from collections import OrderedDict +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi.testclient import TestClient + +pytestmark = pytest.mark.skipif( + sys.platform != "darwin" or platform.machine() != "arm64", + reason="Requires Apple Silicon", +) + + +@pytest.fixture() +def client(): + from vllm_mlx.server import app + + return TestClient(app) + + +@pytest.fixture(autouse=True) +def server_state(): + import vllm_mlx.server as srv + + original_engine = srv._engine + original_model_name = srv._model_name + original_store = srv._responses_store + original_store_max_size = srv._RESPONSES_STORE_MAX_SIZE + original_api_key = srv._api_key + original_default_chat_template_kwargs = getattr( + srv, "_default_chat_template_kwargs", None + ) + + srv._engine = None + srv._model_name = "test-model" + srv._responses_store = OrderedDict() + srv._RESPONSES_STORE_MAX_SIZE = 1000 + srv._api_key = None + srv._default_chat_template_kwargs = None + + try: + yield + finally: + srv._engine = original_engine + srv._model_name = original_model_name + srv._responses_store = original_store + srv._RESPONSES_STORE_MAX_SIZE = original_store_max_size + srv._api_key = original_api_key + srv._default_chat_template_kwargs = original_default_chat_template_kwargs + + +def _mock_engine(*outputs): + engine = MagicMock() + engine.model_name = "test-model" + engine.preserve_native_tool_format = False + engine.chat = AsyncMock(side_effect=list(outputs)) + stream_calls = [] + + async def _stream_chat(**kwargs): + stream_calls.append(kwargs) + for output in getattr(engine, "_stream_outputs", []): + yield output + + engine._stream_calls = stream_calls + engine._stream_outputs = [] + engine.stream_chat = _stream_chat + return engine + + +def _output(text: str, prompt_tokens: int = 7, completion_tokens: int = 3): + return SimpleNamespace( + text=text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason="stop", + ) + + +def _stream_output( + new_text: str, + prompt_tokens: int = 7, + completion_tokens: int = 1, + finish_reason: str | None = None, +): + return SimpleNamespace( + new_text=new_text, + text=new_text, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason=finish_reason, + finished=finish_reason is not None, + ) + + +def _parse_sse_events(body: str) -> list[tuple[str, dict]]: + events = [] + for chunk in body.strip().split("\n\n"): + if not chunk.strip(): + continue + event_type = None + payload = None + for line in chunk.splitlines(): + if line.startswith("event: "): + event_type = line.removeprefix("event: ").strip() + elif line.startswith("data: "): + payload = json.loads(line.removeprefix("data: ").strip()) + if event_type is not None and payload is not None: + events.append((event_type, payload)) + return events + + +class TestResponsesEndpoint: + def test_basic_response(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Hello there")) + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Say hello"}, + ) + + assert resp.status_code 
== 200 + body = resp.json() + assert body["object"] == "response" + assert body["output_text"] == "Hello there" + assert body["output"][0]["type"] == "message" + assert body["output"][0]["content"][0]["type"] == "output_text" + assert body["usage"]["input_tokens"] == 7 + assert body["usage"]["output_tokens"] == 3 + + def test_responses_applies_server_default_chat_template_kwargs(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello there")) + srv._engine = engine + srv._default_chat_template_kwargs = {"enable_thinking": False} + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Say hello"}, + ) + + assert resp.status_code == 200 + assert engine.chat.call_args.kwargs["chat_template_kwargs"] == { + "enable_thinking": False + } + + def test_responses_request_kwargs_override_server_defaults(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello there")) + srv._engine = engine + srv._default_chat_template_kwargs = { + "enable_thinking": False, + "server_default_only": "yes", + } + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Say hello", + "chat_template_kwargs": { + "enable_thinking": True, + "request_only": 1, + }, + }, + ) + + assert resp.status_code == 200 + assert engine.chat.call_args.kwargs["chat_template_kwargs"] == { + "enable_thinking": True, + "server_default_only": "yes", + "request_only": 1, + } + + def test_previous_response_id_reuses_prior_context(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "user" + assert second_messages[0]["content"] == "First prompt" + assert second_messages[1]["role"] == "assistant" + assert second_messages[1]["content"] == "First answer" + assert second_messages[2]["role"] == "user" + assert second_messages[2]["content"] == "Follow-up prompt" + + def test_previous_response_id_chains_across_multiple_follow_ups(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine( + _output("First answer"), + _output("Second answer"), + _output("Third answer"), + ) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first_id, + "input": "Second prompt", + }, + ) + second_id = second.json()["id"] + + third = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": second_id, + "input": "Third prompt", + }, + ) + + assert third.status_code == 200 + third_messages = engine.chat.call_args_list[2].kwargs["messages"] + assert third_messages[0]["role"] == "user" + assert third_messages[0]["content"] == "First prompt" + assert third_messages[1]["role"] == "assistant" + assert third_messages[1]["content"] == "First answer" + assert third_messages[2]["role"] == "user" + assert third_messages[2]["content"] == "Second prompt" + assert 
third_messages[3]["role"] == "assistant" + assert third_messages[3]["content"] == "Second answer" + assert third_messages[4]["role"] == "user" + assert third_messages[4]["content"] == "Third prompt" + + def test_previous_response_id_does_not_carry_prior_instructions(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "First system instruction", + "input": "First prompt", + }, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "Second system instruction", + "previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "system" + assert second_messages[0]["content"] == "Second system instruction" + assert "First system instruction" not in second_messages[0]["content"] + assert second_messages[1]["role"] == "user" + assert second_messages[1]["content"] == "First prompt" + assert second_messages[2]["role"] == "assistant" + assert second_messages[3]["role"] == "user" + + def test_previous_response_id_preserves_prior_system_message_items(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("First answer"), _output("Second answer")) + srv._engine = engine + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "system", "content": "Persist me"}, + {"type": "message", "role": "user", "content": "First prompt"}, + ], + }, + ) + first_id = first.json()["id"] + + second = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first_id, + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 200 + second_messages = engine.chat.call_args_list[1].kwargs["messages"] + assert second_messages[0]["role"] == "system" + assert second_messages[0]["content"] == "Persist me" + assert second_messages[1]["role"] == "user" + assert second_messages[1]["content"] == "First prompt" + assert second_messages[2]["role"] == "assistant" + assert second_messages[3]["role"] == "user" + + def test_previous_response_id_missing_returns_404(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("unused")) + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": "resp_missing", + "input": "Follow-up prompt", + }, + ) + + assert resp.status_code == 404 + assert "resp_missing" in resp.json()["detail"] + + def test_developer_role_is_normalized_to_system(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Ready")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Hi"}, + {"type": "message", "role": "developer", "content": "Be terse"}, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert messages[0]["content"] == "Be terse" + assert messages[1]["role"] == "user" + assert messages[1]["content"] == "Hi" + + def test_instructions_and_developer_message_are_merged(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Ready")) + 
srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "instructions": "System instructions", + "input": [ + { + "type": "message", + "role": "developer", + "content": "Developer note", + }, + {"type": "message", "role": "user", "content": "Hi"}, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert len([m for m in messages if m["role"] == "system"]) == 1 + assert messages[0]["role"] == "system" + assert "System instructions" in messages[0]["content"] + assert "Developer note" in messages[0]["content"] + assert messages[1]["role"] == "user" + + def test_function_call_output_input_is_mapped_cleanly(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Done")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Run it"}, + { + "type": "function_call", + "call_id": "call_1", + "name": "shell", + "arguments": '{"cmd":"pwd"}', + }, + { + "type": "function_call_output", + "call_id": "call_1", + "output": "/tmp/work", + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[1]["role"] == "assistant" + assert "[Calling tool: shell(" in messages[1]["content"] + assert messages[2]["role"] == "user" + assert "[Tool Result (call_1)]" in messages[2]["content"] + assert "/tmp/work" in messages[2]["content"] + + def test_unsupported_tools_and_items_do_not_fail(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Fallback answer")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Answer directly"}, + { + "type": "web_search_call", + "status": "completed", + "action": {"type": "search", "query": "ignored"}, + }, + ], + "tools": [ + {"type": "web_search_preview"}, + {"type": "file_search", "vector_store_ids": ["vs_123"]}, + { + "type": "function", + "name": "shell", + "parameters": {"type": "object", "properties": {}}, + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "system" + assert "not available on this backend" in messages[0]["content"] + assert messages[1]["role"] == "user" + assert engine.chat.call_args.kwargs["tools"][0]["type"] == "function" + + def test_function_call_response_item(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine( + _output('{"name":"shell","arguments":{"cmd":"pwd"}}') + ) + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Use a tool", + "tools": [ + { + "type": "function", + "name": "shell", + "parameters": {"type": "object", "properties": {}}, + } + ], + }, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["output"][0]["type"] == "function_call" + assert body["output"][0]["name"] == "shell" + assert body["output_text"] == "" + + def test_store_false_skips_persistence(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Ephemeral answer")) + + first = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Do not store this", + "store": False, + }, + ) + + assert first.status_code == 200 + assert first.json()["store"] is False + assert first.json()["id"] not in srv._responses_store + + second = 
client.post( + "/v1/responses", + json={ + "model": "test-model", + "previous_response_id": first.json()["id"], + "input": "Follow-up prompt", + }, + ) + + assert second.status_code == 404 + + def test_responses_store_is_lru_bounded(self, client): + import vllm_mlx.server as srv + + srv._RESPONSES_STORE_MAX_SIZE = 2 + srv._engine = _mock_engine( + _output("First answer"), + _output("Second answer"), + _output("Third answer"), + ) + + first = client.post( + "/v1/responses", + json={"model": "test-model", "input": "First prompt"}, + ) + second = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Second prompt"}, + ) + third = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Third prompt"}, + ) + + assert first.status_code == 200 + assert second.status_code == 200 + assert third.status_code == 200 + assert list(srv._responses_store) == [second.json()["id"], third.json()["id"]] + assert first.json()["id"] not in srv._responses_store + + def test_streaming_response_returns_sse_events(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("unused")) + engine.chat = AsyncMock( + side_effect=AssertionError("stream path should not call chat") + ) + engine._stream_outputs = [ + _stream_output("Hello ", completion_tokens=1), + _stream_output("stream", completion_tokens=2, finish_reason="stop"), + ] + srv._engine = engine + + with client.stream( + "POST", + "/v1/responses", + json={"model": "test-model", "input": "Hello", "stream": True}, + ) as resp: + body = "".join(resp.iter_text()) + + assert resp.status_code == 200 + assert "event: response.created" in body + assert "event: response.output_text.delta" in body + assert "Hello stream" in body + assert "event: response.completed" in body + assert len(engine._stream_calls) == 1 + engine.chat.assert_not_awaited() + + def test_streaming_response_sequence_metadata_is_monotonic(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("unused")) + engine.chat = AsyncMock( + side_effect=AssertionError("stream path should not call chat") + ) + engine._stream_outputs = [ + _stream_output("Hello ", completion_tokens=1), + _stream_output("stream", completion_tokens=2, finish_reason="stop"), + ] + srv._engine = engine + + with client.stream( + "POST", + "/v1/responses", + json={"model": "test-model", "input": "Hello", "stream": True}, + ) as resp: + body = "".join(resp.iter_text()) + + assert resp.status_code == 200 + events = _parse_sse_events(body) + assert [event_type for event_type, _ in events[:2]] == [ + "response.created", + "response.in_progress", + ] + sequence_numbers = [payload["sequence_number"] for _, payload in events] + assert sequence_numbers == sorted(sequence_numbers) + created_payload = events[0][1] + completed_payload = next( + payload + for event_type, payload in events + if event_type == "response.completed" + ) + assert created_payload["response"]["id"] == completed_payload["response"]["id"] + assert completed_payload["response"]["output_text"] == "Hello stream" + + def test_streaming_response_bracket_tool_call_does_not_leak_text( + self, client, monkeypatch + ): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("unused")) + engine.chat = AsyncMock( + side_effect=AssertionError("stream path should not call chat") + ) + engine._stream_outputs = [ + _stream_output('[Calling tool: add({"a": 1, "b": 2})'), + _stream_output("]", completion_tokens=2, finish_reason="stop"), + ] + srv._engine = engine + monkeypatch.setattr(srv, 
"_enable_auto_tool_choice", True) + monkeypatch.setattr(srv, "_tool_call_parser", "qwen3") + monkeypatch.setattr(srv, "_tool_parser_instance", None) + monkeypatch.setattr(srv, "_reasoning_parser", None) + + with client.stream( + "POST", + "/v1/responses", + json={ + "model": "test-model", + "input": "Add two numbers", + "stream": True, + "tools": [ + { + "type": "function", + "name": "add", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + }, + } + ], + }, + ) as resp: + body = "".join(resp.iter_text()) + + assert resp.status_code == 200 + events = _parse_sse_events(body) + output_text_deltas = [ + payload["delta"] + for event_type, payload in events + if event_type == "response.output_text.delta" + ] + function_call_deltas = [ + payload + for event_type, payload in events + if event_type == "response.function_call_arguments.delta" + ] + + assert not any("[Calling tool:" in delta for delta in output_text_deltas) + assert len(function_call_deltas) == 1 + assert function_call_deltas[0]["delta"] == '{"a": 1, "b": 2}' + + def test_json_object_response_format_is_rejected(self, client): + import vllm_mlx.server as srv + + srv._engine = _mock_engine(_output("Hello")) + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Hello", + "text": {"format": {"type": "json_object"}}, + }, + ) + + assert resp.status_code == 400 + assert "json_object" in resp.json()["detail"] + + def test_reasoning_configuration_is_ignored(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": "Hello", + "reasoning": {"effort": "xhigh"}, + }, + ) + + assert resp.status_code == 200 + assert engine.chat.await_count == 1 + + def test_reasoning_input_item_is_accepted(self, client): + import vllm_mlx.server as srv + + engine = _mock_engine(_output("Hello")) + srv._engine = engine + + resp = client.post( + "/v1/responses", + json={ + "model": "test-model", + "input": [ + {"type": "message", "role": "user", "content": "Hello"}, + { + "type": "reasoning", + "content": [{"type": "reasoning_text", "text": "x"}], + }, + ], + }, + ) + + assert resp.status_code == 200 + messages = engine.chat.call_args.kwargs["messages"] + assert messages[0]["role"] == "user" + assert messages[0]["content"] == "Hello" + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "x" + + def test_length_finish_reason_marks_response_incomplete(self, client): + import vllm_mlx.server as srv + + output = _output("Cut off", completion_tokens=5) + output.finish_reason = "length" + srv._engine = _mock_engine(output) + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "Hello", "max_output_tokens": 5}, + ) + + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "incomplete" + assert body["incomplete_details"] == {"reason": "max_output_tokens"} diff --git a/tests/test_server.py b/tests/test_server.py index 08b169bd1..63a290f67 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,8 +4,12 @@ import json import platform import sys +from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace import pytest +from fastapi.testclient import TestClient +from pydantic import ValidationError # Skip all tests if not on Apple Silicon pytestmark = 
pytest.mark.skipif( @@ -153,6 +157,17 @@ def test_request_with_video_params(self): assert request.video_fps == 2.0 assert request.video_max_frames == 16 + def test_max_tokens_must_be_positive(self): + """Chat completion requests reject zero or negative max_tokens.""" + from vllm_mlx.server import ChatCompletionRequest, Message + + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="test-model", + messages=[Message(role="user", content="Hello")], + max_tokens=0, + ) + class TestCompletionRequest: """Test CompletionRequest model.""" @@ -167,10 +182,182 @@ def test_basic_completion_request(self): assert request.prompt == "Once upon a time" assert request.max_tokens is None # uses _default_max_tokens when None + def test_max_tokens_must_be_positive(self): + """Completion requests reject zero or negative max_tokens.""" + from vllm_mlx.server import CompletionRequest + + with pytest.raises(ValidationError): + CompletionRequest( + model="test-model", prompt="Once upon a time", max_tokens=0 + ) + + +class TestAnthropicRequest: + """Test Anthropic request model.""" + + def test_max_tokens_must_be_positive(self): + """Anthropic requests reject zero or negative max_tokens.""" + from vllm_mlx.api.anthropic_models import AnthropicRequest + + with pytest.raises(ValidationError): + AnthropicRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=0, + ) + + +class TestMCPExecuteEndpoint: + """Test MCP execute endpoint sandbox routing.""" + + @pytest.fixture() + def client(self): + from fastapi.testclient import TestClient + + from vllm_mlx.server import app + + return TestClient(app) + + def test_execute_routes_through_executor(self, client): + """REST MCP execute should use ToolExecutor, not raw manager.execute_tool.""" + import vllm_mlx.server as srv + + mock_manager = MagicMock() + mock_manager.execute_tool = AsyncMock( + side_effect=AssertionError("manager.execute_tool should not be called") + ) + + mock_result = MagicMock() + mock_result.tool_name = "filesystem__read_file" + mock_result.content = "hello" + mock_result.is_error = False + mock_result.error_message = None + + mock_executor = MagicMock() + mock_executor.execute_tool_calls = AsyncMock( + return_value=[(mock_result, "mcp-test")] + ) + + original_manager = srv._mcp_manager + original_executor = srv._mcp_executor + srv._mcp_manager = mock_manager + srv._mcp_executor = mock_executor + try: + resp = client.post( + "/v1/mcp/execute", + json={ + "tool_name": "filesystem__read_file", + "arguments": {"path": "/tmp/test.txt"}, + }, + ) + finally: + srv._mcp_manager = original_manager + srv._mcp_executor = original_executor + + assert resp.status_code == 200 + body = resp.json() + assert body["tool_name"] == "filesystem__read_file" + assert body["content"] == "hello" + assert body["is_error"] is False + mock_executor.execute_tool_calls.assert_awaited_once() + mock_manager.execute_tool.assert_not_awaited() + + def test_execute_returns_sandbox_blocked_result(self, client): + """REST MCP execute should surface sandbox blocks via executor result.""" + import vllm_mlx.server as srv + + mock_result = MagicMock() + mock_result.tool_name = "filesystem__read_file" + mock_result.content = None + mock_result.is_error = True + mock_result.error_message = "Tool 'read_file' is blocked by security policy" + + mock_executor = MagicMock() + mock_executor.execute_tool_calls = AsyncMock( + return_value=[(mock_result, "mcp-test")] + ) + + original_manager = srv._mcp_manager + original_executor = 
srv._mcp_executor + srv._mcp_manager = MagicMock() + srv._mcp_executor = mock_executor + try: + resp = client.post( + "/v1/mcp/execute", + json={ + "tool_name": "filesystem__read_file", + "arguments": {"path": "../secret.txt"}, + }, + ) + finally: + srv._mcp_manager = original_manager + srv._mcp_executor = original_executor + + assert resp.status_code == 200 + body = resp.json() + assert body["is_error"] is True + assert "blocked by security policy" in body["error_message"] + class TestServeCli: """Test serve CLI argument parsing.""" + def test_trust_remote_code_flag_defaults_false(self): + """Serve CLI should require explicit opt-in for remote code loading.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args(["serve", "mlx-community/Llama-3.2-3B-Instruct-4bit"]) + assert args.trust_remote_code is False + + args = parser.parse_args( + [ + "serve", + "mlx-community/Llama-3.2-3B-Instruct-4bit", + "--trust-remote-code", + ] + ) + assert args.trust_remote_code is True + + def test_host_defaults_to_localhost(self): + """Serve parsers should bind only to localhost unless overridden.""" + from vllm_mlx.cli import create_parser as create_cli_parser + from vllm_mlx.server import create_parser as create_server_parser + + cli_parser = create_cli_parser() + cli_args = cli_parser.parse_args( + ["serve", "mlx-community/Llama-3.2-3B-Instruct-4bit"] + ) + assert cli_args.host == "127.0.0.1" + + server_parser = create_server_parser() + server_args = server_parser.parse_args( + ["--model", "mlx-community/Llama-3.2-3B-Instruct-4bit"] + ) + assert server_args.host == "127.0.0.1" + + def test_max_request_tokens_defaults_and_overrides(self): + """Serve CLI exposes a separate request max_tokens ceiling.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args(["serve", "mlx-community/Llama-3.2-3B-Instruct-4bit"]) + assert args.max_tokens == 32768 + assert args.max_request_tokens == 32768 + + args = parser.parse_args( + [ + "serve", + "mlx-community/Llama-3.2-3B-Instruct-4bit", + "--max-tokens", + "2048", + "--max-request-tokens", + "4096", + ] + ) + assert args.max_tokens == 2048 + assert args.max_request_tokens == 4096 + def test_tool_call_parser_accepts_harmony_aliases(self): """GPT-OSS/Harmony parsers should be selectable from the serve CLI.""" from vllm_mlx.cli import create_parser @@ -202,6 +389,222 @@ def test_tool_call_parser_accepts_harmony_aliases(self): assert args.tool_call_parser == "gpt-oss" + def test_default_chat_template_kwargs_accepts_json_object(self): + """Serve CLI should parse default chat template kwargs from a JSON object.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + [ + "serve", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + '{"enable_thinking": false}', + ] + ) + + assert args.default_chat_template_kwargs == {"enable_thinking": False} + + def test_default_chat_template_kwargs_help_mentions_empty_request_behavior( + self, capsys + ): + """Serve CLI help should explain empty request kwargs keep server defaults.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["serve", "mlx-community/Qwen3-0.6B-8bit", "--help"]) + + captured = capsys.readouterr() + normalized = " ".join(captured.out.split()) + assert "omitted or empty" in normalized + assert "server defaults" in normalized + + @pytest.mark.parametrize("bad_json", ["{not-json}", "{"]) + def 
test_default_chat_template_kwargs_rejects_malformed_json( + self, bad_json, capsys + ): + """Serve CLI should fail fast on malformed JSON input.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "serve", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + bad_json, + ] + ) + + captured = capsys.readouterr() + assert "--default-chat-template-kwargs" in captured.err + assert "JSON object" in captured.err + + @pytest.mark.parametrize("non_object", ["[]", "true", "123"]) + def test_default_chat_template_kwargs_rejects_non_object_json( + self, non_object, capsys + ): + """Serve CLI should reject valid JSON values that are not objects.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "serve", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + non_object, + ] + ) + + captured = capsys.readouterr() + assert "--default-chat-template-kwargs" in captured.err + assert "JSON object" in captured.err + + +class TestStandaloneServerCli: + """Test standalone server CLI argument parsing.""" + + def test_trust_remote_code_flag_defaults_false(self): + """Standalone server should require explicit opt-in for remote code loading.""" + from vllm_mlx.server import create_parser + + parser = create_parser() + args = parser.parse_args( + ["--model", "mlx-community/Llama-3.2-3B-Instruct-4bit"] + ) + assert args.trust_remote_code is False + + args = parser.parse_args( + [ + "--model", + "mlx-community/Llama-3.2-3B-Instruct-4bit", + "--trust-remote-code", + ] + ) + assert args.trust_remote_code is True + + def test_default_chat_template_kwargs_accepts_json_object(self): + """Standalone server should parse default chat template kwargs JSON.""" + from vllm_mlx.server import create_parser + + parser = create_parser() + args = parser.parse_args( + [ + "--model", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + '{"enable_thinking": false}', + ] + ) + + assert args.default_chat_template_kwargs == {"enable_thinking": False} + + def test_default_chat_template_kwargs_help_mentions_empty_request_behavior( + self, capsys + ): + """Standalone help should explain empty request kwargs keep server defaults.""" + from vllm_mlx.server import create_parser + + parser = create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["--help"]) + + captured = capsys.readouterr() + normalized = " ".join(captured.out.split()) + assert "omitted or empty" in normalized + assert "server defaults" in normalized + + @pytest.mark.parametrize("bad_json", ["{not-json}", "{"]) + def test_default_chat_template_kwargs_rejects_malformed_json( + self, bad_json, capsys + ): + """Standalone server should fail fast on malformed JSON input.""" + from vllm_mlx.server import create_parser + + parser = create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "--model", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + bad_json, + ] + ) + + captured = capsys.readouterr() + assert "--default-chat-template-kwargs" in captured.err + assert "JSON object" in captured.err + + @pytest.mark.parametrize("non_object", ["[]", "false", "0"]) + def test_default_chat_template_kwargs_rejects_non_object_json( + self, non_object, capsys + ): + """Standalone server should reject valid JSON values that are not objects.""" + from vllm_mlx.server import create_parser + + parser = create_parser() + 
with pytest.raises(SystemExit): + parser.parse_args( + [ + "--model", + "mlx-community/Qwen3-0.6B-8bit", + "--default-chat-template-kwargs", + non_object, + ] + ) + + captured = capsys.readouterr() + assert "--default-chat-template-kwargs" in captured.err + assert "JSON object" in captured.err + + +class TestLoadModelTrustRemoteCode: + """Test load_model trust_remote_code wiring into engine constructors.""" + + def test_load_model_simple_defaults_trust_remote_code_false(self): + """SimpleEngine should receive trust_remote_code=False by default.""" + from vllm_mlx import server + + fake_engine = MagicMock() + fake_loop = MagicMock() + + with ( + patch.object( + server, "SimpleEngine", return_value=fake_engine + ) as mock_engine, + patch.object(server, "_detect_native_tool_support", return_value=False), + patch("vllm_mlx.server.asyncio.new_event_loop", return_value=fake_loop), + patch("vllm_mlx.server.asyncio.set_event_loop"), + ): + server.load_model("test-model", use_batching=False) + + assert mock_engine.call_args.kwargs["trust_remote_code"] is False + + def test_load_model_batched_forwards_explicit_trust_remote_code(self): + """BatchedEngine should receive explicit trust_remote_code opt-in.""" + from vllm_mlx import server + + fake_engine = MagicMock() + + with ( + patch.object( + server, "BatchedEngine", return_value=fake_engine + ) as mock_engine, + patch.object(server, "_detect_native_tool_support", return_value=False), + ): + server.load_model( + "test-model", + use_batching=True, + trust_remote_code=True, + ) + + assert mock_engine.call_args.kwargs["trust_remote_code"] is True + # ============================================================================= # Helper Function Tests @@ -228,6 +631,17 @@ def test_is_mllm_model_patterns(self): assert not is_mllm_model("mlx-community/Mistral-7B-Instruct-4bit") assert not is_mllm_model("mlx-community/Qwen2-7B-Instruct-4bit") + def test_sanitize_log_text_escapes_control_characters(self): + """Untrusted log text should not contain raw control characters.""" + from vllm_mlx.server import _sanitize_log_text + + text = "line1\nline2\r\t\u2028\x1b[31m" + sanitized = _sanitize_log_text(text) + + assert sanitized == r"line1\nline2\r\t\u2028\x1b[31m" + assert "\n" not in sanitized.replace(r"\n", "") + assert "\r" not in sanitized.replace(r"\r", "") + def test_extract_multimodal_content_text_only(self): """Test extracting content from text-only messages.""" from vllm_mlx.server import extract_multimodal_content, Message @@ -586,6 +1000,248 @@ def test_completion_request_timeout_field(self): assert request_with_timeout.timeout == 120.0 +class TestMaxTokensLimit: + """Test server-side max_tokens ceiling enforcement.""" + + def test_resolve_request_max_tokens(self): + """Explicit requests must stay within the configured server ceiling.""" + import vllm_mlx.server as server + + original_default = server._default_max_tokens + original_limit = server._max_request_tokens + try: + server._default_max_tokens = 1024 + server._max_request_tokens = 2048 + + assert server._resolve_request_max_tokens(None) == 1024 + assert server._resolve_request_max_tokens(512) == 512 + + with pytest.raises(server.HTTPException) as exc_info: + server._resolve_request_max_tokens(4096) + + assert exc_info.value.status_code == 400 + assert "server limit" in exc_info.value.detail + finally: + server._default_max_tokens = original_default + server._max_request_tokens = original_limit + + @pytest.mark.anyio + async def test_create_completion_rejects_over_limit_before_engine_lookup( + 
self, monkeypatch + ): + """Completions should reject oversized requests at the API boundary.""" + from vllm_mlx.server import CompletionRequest, create_completion + import vllm_mlx.server as server + + monkeypatch.setattr(server, "_max_request_tokens", 1024) + monkeypatch.setattr( + server, + "get_engine", + lambda: (_ for _ in ()).throw(AssertionError("engine should not load")), + ) + + request = CompletionRequest( + model="test-model", + prompt="Once upon a time", + max_tokens=2048, + ) + + with pytest.raises(server.HTTPException) as exc_info: + await create_completion(request, raw_request=None) + + assert exc_info.value.status_code == 400 + + @pytest.mark.anyio + async def test_create_chat_completion_rejects_over_limit_before_engine_lookup( + self, monkeypatch + ): + """Chat completions should reject oversized requests at the API boundary.""" + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + create_chat_completion, + ) + import vllm_mlx.server as server + + monkeypatch.setattr(server, "_max_request_tokens", 1024) + monkeypatch.setattr( + server, + "get_engine", + lambda: (_ for _ in ()).throw(AssertionError("engine should not load")), + ) + + request = ChatCompletionRequest( + model="test-model", + messages=[Message(role="user", content="Hello")], + max_tokens=2048, + ) + + with pytest.raises(server.HTTPException) as exc_info: + await create_chat_completion(request, raw_request=None) + + assert exc_info.value.status_code == 400 + + @pytest.mark.anyio + @pytest.mark.parametrize( + ("user_stop", "expected_stop"), + [ + (None, ["<|tool_response>"]), + (["END_USER"], ["END_USER", "<|tool_response>"]), + ], + ) + async def test_create_chat_completion_merges_parser_stop_tokens( + self, monkeypatch, user_stop, expected_stop + ): + """Parser-declared stop tokens should be merged into chat kwargs.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + create_chat_completion, + ) + import vllm_mlx.server as server + + captured = {} + helper_calls = [] + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="ok", + prompt_tokens=5, + completion_tokens=2, + finish_reason="stop", + ) + + fake_engine = FakeEngine() + + async def fake_acquire( + raw_request, *, total_timeout=None, deadline=None, count_activity=True + ): + return fake_engine + + async def fake_release(*, count_activity=True): + return None + + def fake_get_parser_stop_tokens(parser_name, user_stops): + helper_calls.append((parser_name, user_stops)) + merged = list(user_stops or []) + if "<|tool_response>" not in merged: + merged.append("<|tool_response>") + return merged + + monkeypatch.setattr(server, "_validate_model_name", lambda _m: None) + monkeypatch.setattr(server, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(server, "_release_default_engine", fake_release) + monkeypatch.setattr( + server, "get_parser_stop_tokens", fake_get_parser_stop_tokens + ) + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "fake") + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = 
ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="Hello")], + max_tokens=16, + stop=user_stop, + ) + + response = await create_chat_completion(request, raw_request=None) + + assert helper_calls + assert all(call == ("fake", user_stop) for call in helper_calls) + assert captured["kwargs"]["stop"] == expected_stop + assert response.choices[0].message.content == "ok" + + @pytest.mark.anyio + async def test_create_anthropic_message_rejects_over_limit_before_engine_lookup( + self, monkeypatch + ): + """Anthropic requests should reject oversized requests before engine use.""" + from vllm_mlx.server import create_anthropic_message + import vllm_mlx.server as server + + class FakeRequest: + async def json(self): + return { + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 2048, + } + + monkeypatch.setattr(server, "_max_request_tokens", 1024) + monkeypatch.setattr( + server, + "get_engine", + lambda: (_ for _ in ()).throw(AssertionError("engine should not load")), + ) + + with pytest.raises(server.HTTPException) as exc_info: + await create_anthropic_message(FakeRequest()) + + assert exc_info.value.status_code == 400 + + +class TestChatTemplateKwargsResolver: + """Test default chat template kwargs precedence contract.""" + + def test_resolver_prefers_request_values_over_server_defaults(self, monkeypatch): + """Request kwargs should override server defaults key-by-key.""" + import vllm_mlx.server as server + + monkeypatch.setattr( + server, + "_default_chat_template_kwargs", + {"enable_thinking": False, "temperature_hint": "server"}, + raising=False, + ) + + resolved = server._resolve_chat_template_kwargs( + {"enable_thinking": True, "request_only": 1} + ) + assert resolved == { + "enable_thinking": True, + "temperature_hint": "server", + "request_only": 1, + } + + def test_resolver_uses_server_defaults_when_request_omits_kwargs(self, monkeypatch): + """Resolver should return server defaults when request kwargs are absent.""" + import vllm_mlx.server as server + + monkeypatch.setattr( + server, + "_default_chat_template_kwargs", + {"enable_thinking": False}, + raising=False, + ) + + assert server._resolve_chat_template_kwargs(None) == {"enable_thinking": False} + + def test_resolver_returns_empty_dict_when_no_values_are_provided(self, monkeypatch): + """Resolver should produce an empty dict when neither source provides values.""" + import vllm_mlx.server as server + + monkeypatch.setattr( + server, + "_default_chat_template_kwargs", + None, + raising=False, + ) + + assert server._resolve_chat_template_kwargs(None) == {} + + class TestAPIKeyVerification: """Test API key verification with timing attack prevention.""" @@ -636,6 +1292,91 @@ def test_verify_api_key_rejects_invalid(self): finally: server._api_key = original_key + +class TestLogAndExceptionSanitization: + """Test request log previews and internal error responses.""" + + @pytest.mark.anyio + async def test_create_completion_logs_sanitized_prompt_preview( + self, monkeypatch, caplog + ): + """Prompt previews should escape control characters before logging.""" + from vllm_mlx.server import CompletionRequest, create_completion + import vllm_mlx.server as server + + class DummyEngine: + async def generate(self, **kwargs): + return SimpleNamespace( + text="ok", + finish_reason="stop", + completion_tokens=1, + prompt_tokens=1, + ) + + async def fake_wait(task, raw_request, timeout): + return await task + + monkeypatch.setattr(server, "_model_name", "test-model") + 
monkeypatch.setattr(server, "get_engine", lambda: DummyEngine()) + monkeypatch.setattr(server, "_wait_with_disconnect", fake_wait) + + request = CompletionRequest( + model="test-model", + prompt="line1\nline2\t\x1b[31mred", + max_tokens=8, + ) + + with caplog.at_level("INFO"): + response = await create_completion(request, raw_request=None) + + assert response.choices[0].text == "ok" + preview_logs = [ + record.getMessage() + for record in caplog.records + if "prompt_preview=" in record.getMessage() + ] + assert preview_logs + assert "line1\\nline2\\t\\x1b[31mred" in preview_logs[0] + assert "line1\nline2" not in preview_logs[0] + + @pytest.mark.anyio + async def test_create_embeddings_hides_internal_exception_details( + self, monkeypatch, caplog + ): + """Embedding failures should log sanitized details but return generic 500s.""" + from vllm_mlx.server import EmbeddingRequest, create_embeddings + import vllm_mlx.server as server + + class ExplodingEmbeddingEngine: + def count_tokens(self, texts): + raise RuntimeError("boom\nsecret\t\x1b[31m") + + monkeypatch.setattr(server, "_embedding_engine", ExplodingEmbeddingEngine()) + monkeypatch.setattr( + server, "load_embedding_model", lambda *args, **kwargs: None + ) + + request = EmbeddingRequest( + model="mlx-community/all-MiniLM-L6-v2-4bit", + input="hello", + ) + + with caplog.at_level("ERROR"): + with pytest.raises(server.HTTPException) as exc_info: + await create_embeddings(request) + + assert exc_info.value.status_code == 500 + assert exc_info.value.detail == "Embedding generation failed" + + error_logs = [ + record.getMessage() + for record in caplog.records + if "Embedding generation failed:" in record.getMessage() + ] + assert error_logs + assert r"boom\nsecret\t\x1b[31m" in error_logs[0] + assert "boom\nsecret" not in error_logs[0] + def test_verify_api_key_accepts_valid(self): """Test that valid API key is accepted.""" import asyncio @@ -681,9 +1422,83 @@ def test_rate_limiter_returns_retry_after(self): assert allowed is False assert retry_after is not None assert retry_after > 0 - assert retry_after <= 60 # Should be within a minute - def test_rate_limiter_window_cleanup(self): + +class TestEndpointSecurityDependencies: + """Test auth/rate-limit coverage on protected endpoints.""" + + @pytest.fixture + def client(self): + import vllm_mlx.server as server + + return TestClient(server.app) + + @pytest.fixture(autouse=True) + def restore_security_state(self): + import vllm_mlx.server as server + + original_key = server._api_key + original_limiter = server._rate_limiter + try: + yield + finally: + server._api_key = original_key + server._rate_limiter = original_limiter + + @pytest.mark.parametrize( + ("method", "path"), + [ + ("get", "/v1/status"), + ("get", "/v1/cache/stats"), + ("delete", "/v1/cache"), + ("post", "/v1/messages"), + ("post", "/v1/messages/count_tokens"), + ], + ) + def test_endpoints_require_api_key(self, client, method, path): + import vllm_mlx.server as server + + server._api_key = "test-secret" + + kwargs = {} + if method == "post": + kwargs["json"] = { + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 16, + } + + response = getattr(client, method)(path, **kwargs) + assert response.status_code == 401 + assert response.json()["detail"] == "API key required" + + @pytest.mark.parametrize("path", ["/v1/messages", "/v1/messages/count_tokens"]) + def test_anthropic_endpoints_apply_rate_limit(self, client, path, monkeypatch): + import vllm_mlx.server as server + + server._api_key = 
"test-secret" + + def deny_all(_client_id): + return False, 7 + + monkeypatch.setattr(server._rate_limiter, "enabled", True) + monkeypatch.setattr(server._rate_limiter, "is_allowed", deny_all) + + response = client.post( + path, + headers={"Authorization": "Bearer test-secret"}, + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "hi"}], + "max_tokens": 16, + }, + ) + + assert response.status_code == 429 + assert "Rate limit exceeded" in response.json()["detail"] + assert response.headers["Retry-After"] == "7" + + def test_rate_limiter_window_cleanup(self): """Test that rate limiter cleans up old requests from sliding window.""" from vllm_mlx.server import RateLimiter import time @@ -712,6 +1527,264 @@ def test_rate_limiter_window_cleanup(self): class TestStreamChatCompletion: """Tests for streaming chat completion behavior.""" + @pytest.mark.anyio + async def test_stream_without_parser_flags_emits_structured_tool_calls( + self, monkeypatch + ): + """Streaming tools should still parse without explicit parser flags.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", + new_text="", + finished=False, + ), + GenerationOutput( + text="", + new_text="/Users/testuser", + finished=False, + ), + GenerationOutput( + text="", + new_text="", + finished=False, + ), + GenerationOutput( + text="", + new_text="", + finished=True, + finish_reason="stop", + prompt_tokens=5, + completion_tokens=7, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", False) + monkeypatch.setattr(server, "_tool_call_parser", None) + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files in a directory", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert len(tool_payloads) == 1 + delta = tool_payloads[0]["choices"][0]["delta"] + assert delta["tool_calls"][0]["function"]["name"] == "list_directory" + assert delta["tool_calls"][0]["function"]["arguments"] == ( + '{"path": "/Users/testuser"}' + ) + assert delta.get("content") is None + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + assert tool_payloads[0]["usage"] == { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + } + + @pytest.mark.anyio + async def test_stream_without_parser_flags_keeps_plain_text(self, monkeypatch): + """Generic streaming fallback should not interfere with normal text.""" + from 
vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="hello ", finished=False), + GenerationOutput( + text="", + new_text="world", + finished=True, + finish_reason="stop", + prompt_tokens=4, + completion_tokens=2, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", False) + monkeypatch.setattr(server, "_tool_call_parser", None) + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files in a directory", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + assert payloads[1]["choices"][0]["delta"]["content"] == "hello " + assert payloads[2]["choices"][0]["delta"]["content"] == "world" + assert payloads[2]["choices"][0]["finish_reason"] == "stop" + + @pytest.mark.anyio + async def test_auto_parser_streams_bare_bracket_tool_calls(self, monkeypatch): + """Bare bracket tool calls should stream as structured tool_calls.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="[read(", finished=False), + GenerationOutput( + text="", + new_text='{"file_path": "/tmp/test.py"}', + finished=False, + ), + GenerationOutput( + text="", + new_text=")]", + finished=True, + finish_reason="stop", + prompt_tokens=4, + completion_tokens=3, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "auto") + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "read", + "description": "Read a file", + "parameters": { + "type": "object", + "properties": {"file_path": {"type": "string"}}, + "required": ["file_path"], + }, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert len(tool_payloads) == 1 + delta = 
tool_payloads[0]["choices"][0]["delta"] + assert delta["tool_calls"][0]["function"]["name"] == "read" + assert delta["tool_calls"][0]["function"]["arguments"] == ( + '{"file_path": "/tmp/test.py"}' + ) + assert delta.get("content") is None + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + @pytest.mark.anyio async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch): """Tool markup after should emit tool_calls chunks.""" @@ -930,6 +2003,414 @@ def extract_tool_calls_streaming( assert payloads[2]["choices"][0]["finish_reason"] == "stop" +class TestReasoningAndToolCallsNonStreaming: + """Non-streaming coexistence of reasoning extraction and tool parsing.""" + + @pytest.fixture() + def client(self): + """Create a FastAPI test client.""" + from fastapi.testclient import TestClient + + from vllm_mlx.server import app + + return TestClient(app) + + def test_chat_completion_preserves_reasoning_with_tool_calls( + self, client, monkeypatch + ): + """Reasoning should survive when tool calls are present in final output.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ToolCall, FunctionCall + + parsed_inputs = [] + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + return GenerationOutput( + text="Need tool", + prompt_tokens=7, + completion_tokens=3, + finish_reason="stop", + ) + + class FakeReasoningParser: + def extract_reasoning(self, model_output): + assert model_output == "Need tool" + return "Need tool", "" + + def fake_parse_tool_calls(text, request): + parsed_inputs.append(text) + if text == "": + return None, [ + ToolCall( + id="call_1", + type="function", + function=FunctionCall( + name="get_weather", + arguments='{"city":"Paris"}', + ), + ) + ] + return text, None + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr( + server, "_parse_tool_calls_with_parser", fake_parse_tool_calls + ) + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + "max_tokens": 32, + }, + ) + + assert response.status_code == 200 + body = response.json() + choice = body["choices"][0] + assert parsed_inputs == [""] + assert choice["message"]["content"] is None + assert choice["message"]["reasoning_content"] == "Need tool" + assert choice["message"]["tool_calls"][0]["function"]["name"] == "get_weather" + assert choice["finish_reason"] == "tool_calls" + + def test_anthropic_message_preserves_thinking_with_tool_use( + self, client, monkeypatch + ): + """Anthropic non-streaming should emit thinking and tool_use blocks.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ToolCall, 
FunctionCall + + parsed_inputs = [] + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + return GenerationOutput( + text="Need tool", + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + ) + + class FakeReasoningParser: + def extract_reasoning(self, model_output): + assert model_output == "Need tool" + return "Need tool", "" + + def fake_parse_tool_calls(text, request): + parsed_inputs.append(text) + if text == "": + return None, [ + ToolCall( + id="call_1", + type="function", + function=FunctionCall( + name="get_weather", + arguments='{"city":"Paris"}', + ), + ) + ] + return text, None + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr( + server, "_parse_tool_calls_with_parser", fake_parse_tool_calls + ) + + response = client.post( + "/v1/messages", + json={ + "model": "test-model", + "max_tokens": 32, + "messages": [{"role": "user", "content": "Weather?"}], + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + } + ], + }, + ) + + assert response.status_code == 200 + body = response.json() + assert parsed_inputs == [""] + assert body["stop_reason"] == "tool_use" + assert [block["type"] for block in body["content"]] == ["thinking", "tool_use"] + assert body["content"][0]["thinking"] == "Need tool" + assert body["content"][1]["name"] == "get_weather" + assert body["content"][1]["input"] == {"city": "Paris"} + + def test_anthropic_message_applies_server_default_chat_template_kwargs( + self, client, monkeypatch + ): + """Anthropic endpoint should forward server default chat_template_kwargs.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + + captured = {} + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="Final answer", + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + ) + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr( + server, + "_default_chat_template_kwargs", + {"enable_thinking": False}, + raising=False, + ) + + response = client.post( + "/v1/messages", + json={ + "model": "test-model", + "max_tokens": 32, + "messages": [{"role": "user", "content": "Weather?"}], + }, + ) + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False} + + def 
test_anthropic_message_request_kwargs_override_server_defaults( + self, client, monkeypatch + ): + """Anthropic request chat_template_kwargs should override server defaults.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + + captured = {} + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="Final answer", + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + ) + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr( + server, + "_default_chat_template_kwargs", + {"enable_thinking": False, "server_default_only": "yes"}, + raising=False, + ) + + response = client.post( + "/v1/messages", + json={ + "model": "test-model", + "max_tokens": 32, + "messages": [{"role": "user", "content": "Weather?"}], + "chat_template_kwargs": { + "enable_thinking": True, + "request_only": 1, + }, + }, + ) + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == { + "enable_thinking": True, + "server_default_only": "yes", + "request_only": 1, + } + + def test_chat_completion_prepares_messages_once_in_non_stream_path( + self, client, monkeypatch + ): + """Chat non-streaming should prepare request messages a single time.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + + extract_calls = {"count": 0} + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + return GenerationOutput( + text="done", + prompt_tokens=3, + completion_tokens=1, + finish_reason="stop", + ) + + def fake_extract(messages, preserve_native_format=False): + extract_calls["count"] += 1 + return ([{"role": "user", "content": "hi"}], [], []) + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "extract_multimodal_content", fake_extract) + monkeypatch.setattr(server, "_reasoning_parser", None) + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 8, + }, + ) + + assert response.status_code == 200 + assert extract_calls["count"] == 1 + + def test_anthropic_message_prepares_messages_once_in_non_stream_path( + self, client, monkeypatch + ): + """Anthropic non-streaming should prepare request messages a single time.""" + import vllm_mlx.server as server + from vllm_mlx.engine.base import GenerationOutput + + extract_calls = {"count": 0} + + class FakeEngine: + model_name = "fake-engine" + is_mllm = False + preserve_native_tool_format = False + + async 
def chat(self, messages, **kwargs): + return GenerationOutput( + text="done", + prompt_tokens=3, + completion_tokens=1, + finish_reason="stop", + ) + + def fake_extract(messages, preserve_native_format=False): + extract_calls["count"] += 1 + return ([{"role": "user", "content": "hi"}], [], []) + + monkeypatch.setattr(server, "_engine", FakeEngine()) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + monkeypatch.setattr(server, "extract_multimodal_content", fake_extract) + monkeypatch.setattr(server, "_reasoning_parser", None) + + response = client.post( + "/v1/messages", + json={ + "model": "test-model", + "max_tokens": 8, + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + + assert response.status_code == 200 + assert extract_calls["count"] == 1 + + # ============================================================================= # Integration Tests (require running server) # ============================================================================= @@ -994,6 +2475,171 @@ def test_chat_completion(self, server_url): assert data["choices"][0]["message"]["content"] +class TestSseDoneTermination: + """Regression tests for SSE data: [DONE] termination signal. + + Covers #101: streaming responses must always emit exactly one + data: [DONE] event, even when the engine raises mid-stream. + """ + + @pytest.mark.anyio + async def test_stream_completion_normal_emits_done(self, monkeypatch): + """Normal stream_completion yields exactly one [DONE] at the end.""" + from vllm_mlx.api.models import CompletionRequest + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import stream_completion + + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_generate(self, **kwargs): + yield GenerationOutput(text="Hello", new_text="Hello", finished=False) + yield GenerationOutput( + text="Hello world", + new_text=" world", + finished=True, + finish_reason="stop", + prompt_tokens=5, + completion_tokens=2, + ) + + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_max_tokens", 100) + + request = CompletionRequest(model="test-model", prompt="Say hello") + chunks = [ + chunk + async for chunk in stream_completion( + FakeEngine(), + "Say hello", + request, + max_tokens=server._default_max_tokens, + ) + ] + + done_chunks = [c for c in chunks if c == "data: [DONE]\n\n"] + assert ( + len(done_chunks) == 1 + ), f"Expected exactly 1 [DONE], got {len(done_chunks)}" + assert chunks[-1] == "data: [DONE]\n\n", "[DONE] must be the last chunk" + + @pytest.mark.anyio + async def test_stream_completion_exception_still_emits_done(self, monkeypatch): + """When engine raises mid-stream, [DONE] is still emitted via _ensure_sse_terminal.""" + from vllm_mlx.api.models import CompletionRequest + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import _ensure_sse_terminal, stream_completion + + import vllm_mlx.server as server + + class ExplodingEngine: + model_name = "exploding-engine" + + async def stream_generate(self, **kwargs): + yield GenerationOutput( + text="partial", new_text="partial", finished=False + ) + raise RuntimeError("Metal command buffer SIGABRT") + + monkeypatch.setattr(server, 
"_model_name", "test-model") + monkeypatch.setattr(server, "_default_max_tokens", 100) + + request = CompletionRequest(model="test-model", prompt="Say hello") + # Wrap with _ensure_sse_terminal, matching server routing + chunks = [ + chunk + async for chunk in _ensure_sse_terminal( + stream_completion( + ExplodingEngine(), + "Say hello", + request, + max_tokens=server._default_max_tokens, + ), + "data: [DONE]\n\n", + ) + ] + + done_chunks = [c for c in chunks if c == "data: [DONE]\n\n"] + assert ( + len(done_chunks) == 1 + ), f"Expected exactly 1 [DONE], got {len(done_chunks)}" + assert chunks[-1] == "data: [DONE]\n\n", "[DONE] must be the last chunk" + + @pytest.mark.anyio + async def test_ensure_sse_terminal_normal_no_duplicate(self): + """Wrapper passes through the generator's own [DONE] without duplicating.""" + from vllm_mlx.server import _ensure_sse_terminal + + async def happy_generator(): + yield "data: {}\n\n" + yield "data: [DONE]\n\n" + + chunks = [ + chunk + async for chunk in _ensure_sse_terminal( + happy_generator(), "data: [DONE]\n\n" + ) + ] + + done_chunks = [c for c in chunks if c == "data: [DONE]\n\n"] + assert ( + len(done_chunks) == 1 + ), f"Expected exactly 1 [DONE], got {len(done_chunks)}" + + @pytest.mark.anyio + async def test_ensure_sse_terminal_exception_emits_done(self): + """Wrapper emits [DONE] when inner generator raises before reaching it.""" + from vllm_mlx.server import _ensure_sse_terminal + + async def exploding_generator(): + yield "data: {}\n\n" + raise RuntimeError("engine crashed") + + chunks = [ + chunk + async for chunk in _ensure_sse_terminal( + exploding_generator(), "data: [DONE]\n\n" + ) + ] + + done_chunks = [c for c in chunks if c == "data: [DONE]\n\n"] + assert ( + len(done_chunks) == 1 + ), f"Expected exactly 1 [DONE], got {len(done_chunks)}" + assert chunks[-1] == "data: [DONE]\n\n", "[DONE] must be the last chunk" + + @pytest.mark.anyio + async def test_ensure_sse_terminal_anthropic_protocol(self): + """Wrapper emits Anthropic message_stop, not OpenAI [DONE], on exception.""" + import json + + from vllm_mlx.server import _ensure_sse_terminal + + anthropic_terminal = ( + f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n" + ) + + async def exploding_anthropic_stream(): + yield "event: content_block_delta\ndata: {}\n\n" + raise RuntimeError("engine crashed") + + chunks = [ + chunk + async for chunk in _ensure_sse_terminal( + exploding_anthropic_stream(), anthropic_terminal + ) + ] + + # Must emit Anthropic terminal, NOT OpenAI [DONE] + assert chunks[-1] == anthropic_terminal + openai_done = [c for c in chunks if c == "data: [DONE]\n\n"] + assert ( + len(openai_done) == 0 + ), "Must not emit OpenAI [DONE] for Anthropic streams" + + def pytest_addoption(parser): """Add custom command line options.""" parser.addoption( diff --git a/tests/test_server_cache_controls.py b/tests/test_server_cache_controls.py new file mode 100644 index 000000000..a551051ac --- /dev/null +++ b/tests/test_server_cache_controls.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for cache control endpoints.""" + +import sys +import types + +from fastapi.testclient import TestClient + + +def test_cache_stats_includes_engine_cache(monkeypatch): + import vllm_mlx.server as server + + fake_utils = types.ModuleType("mlx_vlm.utils") + fake_utils.get_multimodal_kv_cache_stats = lambda: {"entries": 1} + fake_utils.get_pixel_values_cache_stats = lambda: {"entries": 2} + fake_utils.get_pil_cache_stats = lambda: {"entries": 3} + + class 
DummyEngine: + def get_cache_stats(self): + return {"prefix_cache": {"hits": 7, "misses": 2}} + + original_engine = server._engine + original_api_key = server._api_key + original_module = sys.modules.get("mlx_vlm.utils") + try: + server._engine = DummyEngine() + server._api_key = None + sys.modules["mlx_vlm.utils"] = fake_utils + client = TestClient(server.app) + + response = client.get("/v1/cache/stats") + assert response.status_code == 200 + assert response.json()["engine_cache"] == { + "prefix_cache": {"hits": 7, "misses": 2} + } + finally: + server._engine = original_engine + server._api_key = original_api_key + if original_module is not None: + sys.modules["mlx_vlm.utils"] = original_module + else: + sys.modules.pop("mlx_vlm.utils", None) + + +def test_clear_cache_clears_engine_managed_runtime_caches(monkeypatch): + import vllm_mlx.server as server + + calls = {"multimodal": 0, "pixel": 0, "engine": 0} + fake_utils = types.ModuleType("mlx_vlm.utils") + + def clear_multimodal(): + calls["multimodal"] += 1 + + def clear_pixel(): + calls["pixel"] += 1 + + fake_utils.clear_multimodal_kv_cache = clear_multimodal + fake_utils.clear_pixel_values_cache = clear_pixel + + class DummyEngine: + def clear_runtime_caches(self): + calls["engine"] += 1 + return {"prefix_cache": True} + + original_engine = server._engine + original_api_key = server._api_key + original_module = sys.modules.get("mlx_vlm.utils") + try: + server._engine = DummyEngine() + server._api_key = None + sys.modules["mlx_vlm.utils"] = fake_utils + client = TestClient(server.app) + + response = client.delete("/v1/cache") + assert response.status_code == 200 + assert response.json()["engine_cache"] == {"prefix_cache": True} + assert calls == {"multimodal": 1, "pixel": 1, "engine": 1} + finally: + server._engine = original_engine + server._api_key = original_api_key + if original_module is not None: + sys.modules["mlx_vlm.utils"] = original_module + else: + sys.modules.pop("mlx_vlm.utils", None) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index fe95c3cf0..ac0f508b3 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -4,6 +4,7 @@ import asyncio from unittest.mock import MagicMock, patch +import mlx.core as mx import pytest pytestmark = pytest.mark.anyio @@ -268,6 +269,95 @@ async def test_engine_initialization_creates_lock(self): assert hasattr(engine, "_generation_lock") assert isinstance(engine._generation_lock, asyncio.Lock) + @pytest.mark.anyio + async def test_run_blocking_serialized_rebinds_worker_generation_streams(self): + """Worker-thread MLX generation should get fresh thread-local streams.""" + import importlib + + from vllm_mlx.engine.simple import SimpleEngine + + mlx_lm_generate = importlib.import_module("mlx_lm.generate") + sentinel_stream = object() + + with ( + patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False), + patch("vllm_mlx.mlx_streams.mx.default_device", return_value="gpu"), + patch( + "vllm_mlx.mlx_streams.mx.new_stream", + return_value=sentinel_stream, + ), + patch("vllm_mlx.mlx_streams.mx.set_default_stream"), + ): + engine = SimpleEngine("test-model") + observed = await engine._run_blocking_serialized( + lambda: mlx_lm_generate.generation_stream + ) + + assert observed is sentinel_stream + + @pytest.mark.anyio + async def test_start_keeps_text_routing_for_mllm_without_mtp(self): + """MLLM text-only routing must stay available when MTP is disabled.""" + from vllm_mlx.engine.simple import SimpleEngine + + text_model = MagicMock() + 
text_model.mtp = None
+        tokenizer = MagicMock()
+        tokenizer.convert_tokens_to_ids.return_value = 42
+
+        mock_mllm = MagicMock()
+        mock_mllm.model = MagicMock()
+        mock_mllm.get_tokenizer.return_value = tokenizer
+
+        with (
+            patch(
+                "vllm_mlx.models.mllm.MLXMultimodalLM",
+                return_value=mock_mllm,
+            ),
+            patch(
+                "vllm_mlx.text_model_from_vlm.build_text_model",
+                return_value=text_model,
+            ),
+        ):
+            engine = SimpleEngine("qwen3.6-27b", force_mllm=True, mtp=False)
+            await engine.start()
+
+        assert engine._text_model is text_model
+        assert engine._text_tokenizer is tokenizer
+
+    @pytest.mark.anyio
+    async def test_mllm_nonstream_text_only_routes_without_mtp(self):
+        """Non-stream text-only MLLM chat must aggregate the TextModel route."""
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        async def fake_stream_chat(*args, **kwargs):
+            yield MagicMock(
+                text="Hello",
+                tokens=[1],
+                prompt_tokens=5,
+                completion_tokens=1,
+                finish_reason="stop",
+                finished=True,
+            )
+
+        engine = SimpleEngine("test-model", force_mllm=True, mtp=False)
+        engine._loaded = True
+        engine._text_model = MagicMock()
+        engine._model = MagicMock()
+        engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]
+
+        output = await engine.chat(
+            messages=[{"role": "user", "content": "hello"}],
+            max_tokens=16,
+        )
+
+        assert output.text == "Hello"
+        assert output.tokens == [1]
+        assert output.prompt_tokens == 5
+        assert output.completion_tokens == 1
+        assert output.finish_reason == "stop"
+        engine._model.chat.assert_not_called()
+
     @pytest.mark.anyio
     async def test_requests_complete_in_order(self, mock_model):
         """Test that concurrent requests complete (may be in any order due to lock)."""
@@ -291,7 +381,7 @@ async def test_requests_complete_in_order(self, mock_model):
         for result in results:
             assert result.text == "test response"
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_generate_accumulates_over_stream_generate(self):
         """generate() should iterate stream_generate() and return the last
         yielded GenerationOutput, forwarding per-request kwargs (including
@@ -353,7 +443,7 @@ async def fake_stream_generate(**kwargs):
         assert captured_kwargs.get("specprefill") is True
         assert captured_kwargs.get("specprefill_keep_pct") == 0.2
 
-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_generate_empty_stream_returns_safe_default(self):
         """If stream_generate yields nothing, generate() returns an empty
         stop-reason GenerationOutput rather than raising.
@@ -373,3 +463,794 @@ async def empty_stream_generate(**kwargs): assert output.text == "" assert output.finish_reason == "stop" + + def test_seed_logits_processors_prepends_prompt_tokens(self): + """Continuation decode processors must see the original prompt prefix.""" + from vllm_mlx.engine.simple import _seed_logits_processors + + seen = {} + + def processor(tokens, logits): + seen["tokens"] = tokens.tolist() + return logits + + seeded = _seed_logits_processors( + mx.array([10, 11], dtype=mx.uint32), [processor] + ) + + logits = mx.zeros((1, 8), dtype=mx.float32) + seeded[0](mx.array([12, 13], dtype=mx.uint32), logits) + + assert seen["tokens"] == [10, 11, 12, 13] + + @pytest.mark.anyio + async def test_specprefill_success_preserves_mtp_path(self): + """Successful sparse prefill should continue through the normal MTP path.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + captured = {} + + def fake_make_sampler(**kwargs): + captured["sampler_kwargs"] = kwargs + + def _sample(_logprobs): + return mx.array([17], dtype=mx.uint32) + + return _sample + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + captured["prompt"] = prompt.tolist() + captured["kwargs"] = kwargs + yield SimpleNamespace(text="B", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 99 + tokenizer.encode.return_value = [5, 6, 7] + tokenizer.decode.side_effect = lambda ids: "".join( + {17: "A", 99: ""}.get(tok, f"<{tok}>") for tok in ids + ) + + text_model = MagicMock() + text_model.mtp = object() + text_model.make_mtp_cache.return_value = ["mtp-cache"] + + engine = SimpleEngine( + "test-model", + force_mllm=True, + mtp=True, + mtp_num_draft_tokens=4, + specprefill_enabled=True, + specprefill_threshold=1, + ) + engine._loaded = True + engine._text_model = text_model + engine._text_tokenizer = tokenizer + engine._draft_model = object() + + with ( + patch("vllm_mlx.engine.simple._bind_worker_generation_streams"), + patch( + "mlx_lm.models.cache.make_prompt_cache", + return_value=["backbone-cache"], + ), + patch("mlx_lm.sample_utils.make_sampler", side_effect=fake_make_sampler), + patch( + "mlx_lm.sample_utils.make_logits_processors", + return_value=[], + ), + patch("mlx_lm.stream_generate", side_effect=fake_stream_generate), + patch( + "vllm_mlx.specprefill.score_tokens", + return_value=mx.array([1.0, 0.9, 0.8], dtype=mx.float32), + ), + patch( + "vllm_mlx.specprefill.select_chunks", + return_value=mx.array([0, 1, 2], dtype=mx.int32), + ), + patch( + "vllm_mlx.specprefill.sparse_prefill", + return_value=mx.zeros((1, 3, 32), dtype=mx.float32), + ), + patch("vllm_mlx.specprefill.cleanup_rope"), + ): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=4, + temperature=0.6, + top_p=0.95, + ) + ] + + assert [chunk.new_text for chunk in outputs] == ["A", "B"] + assert captured["sampler_kwargs"] == { + "temp": 0.6, + "top_p": 0.95, + "top_k": 0, + "min_p": 0.0, + } + assert captured["prompt"] == [17] + assert captured["kwargs"]["mtp"] is True + assert captured["kwargs"]["prompt_cache"] == ["backbone-cache", "mtp-cache"] + assert captured["kwargs"]["max_tokens"] == 3 + assert captured["kwargs"]["logits_processors"] is None + + @pytest.mark.anyio + async def test_stream_generate_text_forwards_logits_processors_and_sampler_args( + self, + ): + """Text routing must 
preserve request-local decoding controls.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + sampler_calls = [] + penalty_calls = [] + user_processor = MagicMock() + penalty_processor = MagicMock() + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + captured_kwargs.update(kwargs) + yield SimpleNamespace(text="Hello", finish_reason="stop") + + def fake_make_sampler(**kwargs): + sampler_calls.append(kwargs) + return MagicMock() + + def fake_make_logits_processors(**kwargs): + penalty_calls.append(kwargs) + return [penalty_processor] + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + + engine = SimpleEngine("test-model", force_mllm=True, mtp=False) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = None + engine._text_tokenizer = tokenizer + + with ( + patch("mlx_lm.stream_generate", side_effect=fake_stream_generate), + patch("mlx_lm.sample_utils.make_sampler", side_effect=fake_make_sampler), + patch( + "mlx_lm.sample_utils.make_logits_processors", + side_effect=fake_make_logits_processors, + ), + ): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.3, + top_p=0.8, + top_k=40, + min_p=0.1, + presence_penalty=1.5, + repetition_penalty=1.2, + logits_processors=[user_processor], + ) + ] + + assert outputs[-1].text == "Hello" + assert sampler_calls == [{"temp": 0.3, "top_p": 0.8, "top_k": 40, "min_p": 0.1}] + assert penalty_calls == [{"repetition_penalty": 1.2, "presence_penalty": 1.5}] + assert captured_kwargs["logits_processors"] == [ + user_processor, + penalty_processor, + ] + + @pytest.mark.anyio + async def test_stream_generate_text_disables_mtp_when_logits_processors_active( + self, + ): + """Custom logits processors must fail closed to non-MTP decoding.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + user_processor = MagicMock() + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + captured_kwargs.update(kwargs) + yield SimpleNamespace(text="Hello", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + + engine = SimpleEngine("test-model", force_mllm=True, mtp=True) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = MagicMock() + engine._text_tokenizer = tokenizer + + with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + logits_processors=[user_processor], + ) + ] + + assert outputs[-1].text == "Hello" + assert "mtp" not in captured_kwargs + assert captured_kwargs["logits_processors"][0] is user_processor + + @pytest.mark.anyio + async def test_stream_generate_text_disables_mtp_for_thinking_processor( + self, + ): + """Thinking-budget processors must fail closed to non-MTP decoding.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + captured_kwargs.update(kwargs) + yield 
SimpleNamespace(text="Hello", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + + engine = SimpleEngine("test-model", force_mllm=True, mtp=True) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = MagicMock() + engine._text_tokenizer = tokenizer + + thinking_proc = MagicMock() + + with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + logits_processors=[thinking_proc], + ) + ] + + assert outputs[-1].text == "Hello" + assert "mtp" not in captured_kwargs + assert captured_kwargs["logits_processors"][0] is thinking_proc + + @pytest.mark.anyio + async def test_stream_generate_text_passes_num_draft_tokens(self): + """Text routing should forward configured MTP draft depth.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + captured_kwargs.update(kwargs) + yield SimpleNamespace(text="Hello", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + + engine = SimpleEngine( + "test-model", + force_mllm=True, + mtp=True, + mtp_num_draft_tokens=4, + ) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = MagicMock() + engine._text_tokenizer = tokenizer + + with patch("mlx_lm.stream_generate", side_effect=fake_stream_generate): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + ) + ] + + assert outputs[-1].text == "Hello" + assert captured_kwargs["mtp"] is True + assert captured_kwargs["num_draft_tokens"] == 4 + + @pytest.mark.anyio + async def test_stream_generate_text_reenables_mtp_after_retired_processor_when_enabled( + self, + ): + """Retired thinking processor handoff is an explicit opt-in path.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + calls = [] + + class RetiringProcessor: + def __init__(self): + self.is_retired = False + + def __call__(self, tokens, logits): + return logits + + processor = RetiringProcessor() + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + calls.append({"prompt": prompt, **kwargs}) + if len(calls) == 1: + processor.is_retired = True + yield SimpleNamespace(token=11, text="Hello", finish_reason=None) + else: + yield SimpleNamespace(token=12, text=" world", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + tokenizer.encode.return_value = [11] + + engine = SimpleEngine( + "test-model", + force_mllm=True, + mtp=True, + mtp_num_draft_tokens=4, + ) + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = MagicMock() + engine._text_model.make_mtp_cache.return_value = [] + engine._text_tokenizer = tokenizer + + with ( + patch.dict( + "os.environ", + {"VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME": "1"}, + ), + patch("mlx_lm.stream_generate", side_effect=fake_stream_generate), + 
patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch("mlx_lm.models.cache.trim_prompt_cache"), + ): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + logits_processors=[processor], + ) + ] + + assert outputs[-1].text == "Hello world" + assert len(calls) == 2 + assert "mtp" not in calls[0] + assert calls[0]["logits_processors"][0] is processor + assert calls[1]["mtp"] is True + assert calls[1]["num_draft_tokens"] == 4 + assert "logits_processors" not in calls[1] + + @pytest.mark.anyio + async def test_stream_generate_text_specprefill_reenables_mtp_after_retirement( + self, + ): + """SpecPrefill retirement-to-MTP continuation is explicit opt-in.""" + from types import SimpleNamespace + + from vllm_mlx.engine.simple import SimpleEngine + + calls = [] + + class RetiringProcessor: + def __init__(self): + self.is_retired = False + + def __call__(self, tokens, logits): + return logits + + processor = RetiringProcessor() + + def fake_stream_generate(model, tokenizer, prompt, **kwargs): + calls.append({"prompt": prompt, **kwargs}) + yield SimpleNamespace(token=12, text=" world", finish_reason="stop") + + tokenizer = MagicMock() + tokenizer.apply_chat_template.return_value = "<|im_start|>user\nhello" + tokenizer.bos_token = None + tokenizer.eos_token_id = 42 + tokenizer.encode.return_value = [1, 2, 3, 4] + tokenizer.decode.side_effect = lambda toks: "Hello" if toks == [11] else "" + + engine = SimpleEngine( + "test-model", + force_mllm=True, + mtp=True, + mtp_num_draft_tokens=4, + specprefill_enabled=True, + ) + engine._loaded = True + engine._draft_model = MagicMock() + engine._text_model = MagicMock() + engine._text_model.mtp = MagicMock() + engine._text_model.make_mtp_cache.return_value = [] + engine._text_tokenizer = tokenizer + + def fake_sample(tokens, logits, sampler, logits_processors): + processor.is_retired = True + return mx.array(11, dtype=mx.uint32), logits + + with ( + patch.dict( + "os.environ", + {"VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME": "1"}, + ), + patch("mlx_lm.stream_generate", side_effect=fake_stream_generate), + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch( + "vllm_mlx.specprefill.score_tokens", return_value=mx.array([0.1, 0.2]) + ), + patch("vllm_mlx.specprefill.select_chunks", return_value=mx.array([0, 1])), + patch( + "vllm_mlx.specprefill.sparse_prefill", + return_value=mx.zeros((1, 1, 32)), + ), + patch("vllm_mlx.specprefill.cleanup_rope"), + patch( + "vllm_mlx.engine.simple._sample_with_processors", + side_effect=fake_sample, + ), + ): + outputs = [ + chunk + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=16, + temperature=0.7, + top_p=0.9, + specprefill=True, + logits_processors=[processor], + ) + ] + + assert outputs[-1].text == "Hello world" + assert len(calls) == 1 + assert calls[0]["mtp"] is True + assert calls[0]["num_draft_tokens"] == 4 + assert "logits_processors" not in calls[0] + + @pytest.mark.anyio + async def test_cancellation_does_not_release_lock_before_worker_finishes( + self, mock_llm_model + ): + """A cancelled blocking chat call must not overlap the next worker.""" + from threading import Event, Lock + + from vllm_mlx.engine.simple import SimpleEngine + + first_started = Event() + release_workers = Event() + call_count = 0 + call_lock = Lock() + + def chat_side_effect(**kwargs): + nonlocal call_count + with 
call_lock: + call_count += 1 + current_call = call_count + mock_llm_model._concurrent_count += 1 + mock_llm_model._max_concurrent = max( + mock_llm_model._max_concurrent, + mock_llm_model._concurrent_count, + ) + if current_call == 1: + first_started.set() + + try: + release_workers.wait(timeout=1.0) + result = MagicMock() + result.text = f"response-{current_call}" + result.tokens = [1, 2, 3] + result.finish_reason = "stop" + return result + finally: + with call_lock: + mock_llm_model._concurrent_count -= 1 + + mock_llm_model.chat = MagicMock(side_effect=chat_side_effect) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = mock_llm_model + engine._loaded = True + + task1 = asyncio.create_task( + engine.chat( + messages=[{"role": "user", "content": "first"}], max_tokens=8 + ) + ) + await asyncio.to_thread(first_started.wait, 1.0) + + task1.cancel() + task2 = asyncio.create_task( + engine.chat( + messages=[{"role": "user", "content": "second"}], max_tokens=8 + ) + ) + + await asyncio.sleep(0.05) + release_workers.set() + + with pytest.raises(asyncio.CancelledError): + await task1 + result2 = await task2 + + assert result2.text == "response-2" + assert mock_llm_model._max_concurrent == 1 + + @pytest.mark.anyio + async def test_specprefill_path_does_not_prelock_serialized_runner(self): + """SpecPrefill should let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + assert not engine._generation_lock.locked() + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + engine._draft_model = MagicMock() + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + assert len(outputs) == 1 + assert outputs[0].finished + assert outputs[0].completion_tokens == 0 + + @pytest.mark.anyio + async def test_text_mtp_path_does_not_prelock_serialized_runner(self): + """Text-only MTP path should let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + assert not engine._generation_lock.locked() + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.make_mtp_cache = MagicMock(return_value=[]) + engine._text_tokenizer = MagicMock() + engine._text_tokenizer.apply_chat_template = MagicMock(return_value="hello") + engine._text_tokenizer.bos_token = None + engine._draft_model = None + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + assert len(outputs) == 1 + assert outputs[0].finished + assert outputs[0].completion_tokens == 0 + + @pytest.mark.anyio + async def test_specprefill_threads_same_cancel_check_to_helpers(self): + """SpecPrefill worker should pass one 
cooperative cancel hook through both phases.""" + from vllm_mlx.engine.simple import SimpleEngine + + captured = {} + + def fake_score_tokens(*args, cancel_check=None, **kwargs): + captured["score"] = cancel_check + return mx.array([0.5], dtype=mx.float32) + + def fake_sparse_prefill(*args, cancel_check=None, **kwargs): + captured["prefill"] = cancel_check + return mx.zeros((1, 1, 8), dtype=mx.float32) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._draft_model = MagicMock() + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + engine._model.tokenizer.decode = MagicMock(return_value="A") + engine._model.tokenizer.eos_token_id = 0 + + outputs = [] + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch( + "mlx_lm.sample_utils.make_sampler", + return_value=lambda logits: mx.array([0], dtype=mx.int32), + ), + patch( + "vllm_mlx.specprefill.score_tokens", side_effect=fake_score_tokens + ), + patch( + "vllm_mlx.specprefill.select_chunks", + return_value=mx.array([0], dtype=mx.int32), + ), + patch( + "vllm_mlx.specprefill.sparse_prefill", + side_effect=fake_sparse_prefill, + ), + patch("vllm_mlx.specprefill.cleanup_rope"), + ): + async for chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk.new_text) + + assert outputs == ["A"] + assert callable(captured["score"]) + assert captured["score"] is captured["prefill"] + + @pytest.mark.anyio + async def test_cancelling_specprefill_request_stops_during_scoring(self): + """Cancelling SpecPrefill should signal the blocking scorer and exit without output.""" + import time + from threading import Event + + from vllm_mlx.engine.simple import SimpleEngine, _SpecPrefillCancelled + + score_started = Event() + score_cancelled = Event() + + def fake_score_tokens(*args, cancel_check=None, **kwargs): + assert callable(cancel_check) + score_started.set() + while True: + try: + cancel_check() + except _SpecPrefillCancelled: + score_cancelled.set() + raise + time.sleep(0.01) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._draft_model = MagicMock() + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + + async def consume(): + async for _chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + pytest.fail("Cancelled SpecPrefill request should not emit output") + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch( + "vllm_mlx.specprefill.score_tokens", + side_effect=fake_score_tokens, + ), + patch("vllm_mlx.specprefill.cleanup_rope"), + ): + task = asyncio.create_task(consume()) + assert await asyncio.to_thread(score_started.wait, 1.0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert await asyncio.to_thread(score_cancelled.wait, 1.0) + + @pytest.mark.anyio + async def test_cancelling_specprefill_request_stops_during_sparse_prefill(self): + """Cancelling SpecPrefill should signal the sparse-prefill loop and exit without output.""" + import time + from threading import Event + + from vllm_mlx.engine.simple import SimpleEngine, _SpecPrefillCancelled + + prefill_started = Event() + 
prefill_cancelled = Event() + + def fake_sparse_prefill(*args, cancel_check=None, **kwargs): + assert callable(cancel_check) + prefill_started.set() + while True: + try: + cancel_check() + except _SpecPrefillCancelled: + prefill_cancelled.set() + raise + time.sleep(0.01) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._draft_model = MagicMock() + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + + async def consume(): + async for _chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + pytest.fail("Cancelled SpecPrefill request should not emit output") + + with ( + patch("mlx_lm.models.cache.make_prompt_cache", return_value=[]), + patch( + "vllm_mlx.specprefill.score_tokens", + return_value=mx.array([0.5], dtype=mx.float32), + ), + patch( + "vllm_mlx.specprefill.select_chunks", + return_value=mx.array([0], dtype=mx.int32), + ), + patch( + "vllm_mlx.specprefill.sparse_prefill", + side_effect=fake_sparse_prefill, + ), + patch("vllm_mlx.specprefill.cleanup_rope"), + ): + task = asyncio.create_task(consume()) + assert await asyncio.to_thread(prefill_started.wait, 1.0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + assert await asyncio.to_thread(prefill_cancelled.wait, 1.0) diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py index 28c25868e..e5434c36a 100644 --- a/tests/test_simple_engine_cancel_serialization.py +++ b/tests/test_simple_engine_cancel_serialization.py @@ -3,78 +3,14 @@ from __future__ import annotations +import unittest import asyncio import threading -import unittest +from types import SimpleNamespace from unittest.mock import MagicMock, patch class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase): - async def test_cancellation_does_not_release_lock_before_worker_finishes(self): - """A cancelled request must not let a second MLX worker overlap.""" - from vllm_mlx.engine.simple import SimpleEngine - - model = MagicMock() - model.tokenizer = MagicMock() - model.tokenizer.encode = MagicMock(return_value=[1, 2, 3]) - model._concurrent_count = 0 - model._max_concurrent = 0 - - first_started = threading.Event() - release_workers = threading.Event() - call_count = 0 - call_lock = threading.Lock() - - def generate_side_effect(**kwargs): - nonlocal call_count - with call_lock: - call_count += 1 - current_call = call_count - model._concurrent_count += 1 - model._max_concurrent = max( - model._max_concurrent, model._concurrent_count - ) - if current_call == 1: - first_started.set() - - release_workers.wait(timeout=1.0) - - with call_lock: - model._concurrent_count -= 1 - - result = MagicMock() - result.text = f"response-{current_call}" - result.tokens = [1, 2, 3] - result.finish_reason = "stop" - return result - - model.generate = MagicMock(side_effect=generate_side_effect) - - with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): - engine = SimpleEngine("test-model") - engine._model = model - engine._loaded = True - - task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8)) - await asyncio.to_thread(first_started.wait, 1.0) - - task1.cancel() - task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8)) - - await asyncio.sleep(0.05) - release_workers.set() - - with 
self.assertRaises(asyncio.CancelledError): - await task1 - result2 = await task2 - - self.assertEqual(result2.text, "response-2") - self.assertEqual( - model._max_concurrent, - 1, - "cancellation released the generation lock before the first worker finished", - ) - async def test_specprefill_path_does_not_prelock_serialized_runner(self): """Specprefill streaming must let _run_blocking_serialized own the lock.""" from vllm_mlx.engine.simple import SimpleEngine @@ -138,6 +74,69 @@ async def fake_serialized(func, *args, **kwargs): self.assertTrue(outputs[0].finished) self.assertEqual(outputs[0].completion_tokens, 0) + async def test_text_route_stream_cancel_stops_after_next_token_boundary(self): + """Client disconnect should not let text-route workers drain max_tokens.""" + import mlx_lm + + from vllm_mlx.engine.simple import SimpleEngine + + second_token_allowed = threading.Event() + second_token_requested = threading.Event() + consumed_tokens = [] + + def fake_stream_generate(*args, **kwargs): + consumed_tokens.append("A") + yield SimpleNamespace(text="A", finish_reason=None) + + second_token_requested.set() + second_token_allowed.wait(timeout=1.0) + consumed_tokens.append("B") + yield SimpleNamespace(text="B", finish_reason=None) + + for token in ("C", "D", "E"): + consumed_tokens.append(token) + yield SimpleNamespace(text=token, finish_reason=None) + + tokenizer = MagicMock() + tokenizer.bos_token = None + tokenizer.apply_chat_template.return_value = "prompt" + tokenizer.encode.return_value = [1, 2, 3] + + with ( + patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True), + patch( + "vllm_mlx.engine.simple._bind_worker_generation_streams", + return_value=None, + ), + patch.object(mlx_lm, "stream_generate", side_effect=fake_stream_generate), + ): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.mtp = None + engine._text_tokenizer = tokenizer + engine._draft_model = None + + stream = engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=8, + temperature=0.7, + top_p=0.9, + ) + first = await stream.__anext__() + self.assertEqual(first.new_text, "A") + + next_task = asyncio.create_task(stream.__anext__()) + await asyncio.to_thread(second_token_requested.wait, 1.0) + next_task.cancel() + await asyncio.sleep(0) + second_token_allowed.set() + + with self.assertRaises(asyncio.CancelledError): + await next_task + + self.assertEqual(consumed_tokens, ["A", "B"]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_ssd_cache.py b/tests/test_ssd_cache.py new file mode 100644 index 000000000..0f2375db6 --- /dev/null +++ b/tests/test_ssd_cache.py @@ -0,0 +1,966 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for SSD KV cache tiering.""" + +import pytest + +from vllm_mlx.ssd_cache import SSDCacheConfig, SSDCacheStats + + +class TestSSDCacheConfig: + """Tests for SSDCacheConfig.""" + + def test_default_config(self): + config = SSDCacheConfig() + assert config.cache_dir is None + assert config.max_size_gb == 10.0 + assert config.max_entries == 10000 + assert config.file_permissions == 0o600 + assert config.dir_permissions == 0o700 + assert config.spill_queue_size == 64 + assert config.retention_seconds is None + + def test_custom_config(self): + config = SSDCacheConfig( + cache_dir="/tmp/test-ssd-cache", + max_size_gb=5.0, + max_entries=500, + ) + assert config.cache_dir == "/tmp/test-ssd-cache" + assert config.max_size_gb == 5.0 + assert config.max_entries == 500 + + def 
test_max_size_bytes(self): + config = SSDCacheConfig(max_size_gb=2.0) + assert config.max_size_bytes == 2 * 1024 * 1024 * 1024 + + def test_invalid_max_size(self): + with pytest.raises(ValueError, match="max_size_gb"): + SSDCacheConfig(max_size_gb=0.0) + + def test_invalid_max_entries(self): + with pytest.raises(ValueError, match="max_entries"): + SSDCacheConfig(max_entries=0) + + def test_invalid_spill_queue_size(self): + with pytest.raises(ValueError, match="spill_queue_size"): + SSDCacheConfig(spill_queue_size=0) + + +class TestSSDCacheStats: + """Tests for SSDCacheStats.""" + + def test_initial_stats(self): + stats = SSDCacheStats() + assert stats.spill_count == 0 + assert stats.spill_bytes == 0 + assert stats.ssd_hits == 0 + assert stats.ssd_misses == 0 + assert stats.reload_latency_sum == 0.0 + assert stats.reload_bytes == 0 + assert stats.promotion_failures == 0 + + def test_to_dict(self): + stats = SSDCacheStats( + spill_count=10, + spill_bytes=1024 * 1024, + ssd_hits=5, + ssd_misses=3, + reload_latency_sum=0.5, + reload_bytes=512 * 1024, + promotion_failures=1, + ) + d = stats.to_dict() + assert d["spill_count"] == 10 + assert d["spill_bytes"] == 1024 * 1024 + assert d["ssd_hits"] == 5 + assert d["ssd_misses"] == 3 + assert d["reload_bytes"] == 512 * 1024 + assert d["promotion_failures"] == 1 + assert d["ssd_hit_rate"] == pytest.approx(5 / 8) + assert d["avg_reload_latency_ms"] == pytest.approx(100.0) + + def test_hit_rate_no_lookups(self): + stats = SSDCacheStats() + d = stats.to_dict() + assert d["ssd_hit_rate"] == 0.0 + assert d["avg_reload_latency_ms"] == 0.0 + + +import os + +from vllm_mlx.ssd_cache import SSDIndex + + +class TestSSDIndex: + """Tests for SQLite-backed SSD cache index.""" + + @pytest.fixture + def db_dir(self, tmp_path): + return str(tmp_path / "ssd_index_test") + + @pytest.fixture + def index(self, db_dir): + os.makedirs(db_dir, exist_ok=True) + idx = SSDIndex(db_dir) + yield idx + idx.close() + + def test_create_opens_db(self, index, db_dir): + assert os.path.exists(os.path.join(db_dir, "index.db")) + + def test_insert_and_lookup_exact(self, index): + tokens = (1, 2, 3, 4, 5) + index.insert_entry( + tokens_key=tokens, + file_path="entry_abc123.safetensors", + memory_bytes=4096, + num_tokens=5, + ) + result = index.lookup_exact(tokens) + assert result is not None + assert result["file_path"] == "entry_abc123.safetensors" + assert result["memory_bytes"] == 4096 + assert result["num_tokens"] == 5 + + def test_lookup_exact_miss(self, index): + result = index.lookup_exact((99, 98, 97)) + assert result is None + + def test_lookup_prefix(self, index): + # Insert a few entries + index.insert_entry((1, 2, 3), "a.safetensors", 1000, 3) + index.insert_entry((1, 2, 3, 4, 5), "b.safetensors", 2000, 5) + index.insert_entry((1, 2, 3, 4, 5, 6, 7), "c.safetensors", 3000, 7) + index.insert_entry((9, 8, 7), "d.safetensors", 1000, 3) + + # Lookup prefix matches for (1,2,3,4,5,6,7,8) + results = index.lookup_prefix((1, 2, 3, 4, 5, 6, 7, 8)) + # Should return entries whose tokens are a prefix of the query + file_paths = [r["file_path"] for r in results] + assert "a.safetensors" in file_paths + assert "b.safetensors" in file_paths + assert "c.safetensors" in file_paths + assert "d.safetensors" not in file_paths + + def test_delete_entry(self, index): + tokens = (10, 20, 30) + index.insert_entry(tokens, "x.safetensors", 500, 3) + assert index.lookup_exact(tokens) is not None + + index.delete_entry(tokens) + assert index.lookup_exact(tokens) is None + + def test_get_lru(self, 
index): + import time + + index.insert_entry((1,), "a.safetensors", 100, 1) + time.sleep(0.01) # Ensure different timestamps + index.insert_entry((2,), "b.safetensors", 200, 1) + time.sleep(0.01) + index.insert_entry((3,), "c.safetensors", 300, 1) + + # Get oldest 2 entries + lru = index.get_lru(limit=2) + assert len(lru) == 2 + assert lru[0]["file_path"] == "a.safetensors" + assert lru[1]["file_path"] == "b.safetensors" + + def test_get_total_bytes(self, index): + index.insert_entry((1,), "a.safetensors", 1000, 1) + index.insert_entry((2,), "b.safetensors", 2000, 1) + assert index.get_total_bytes() == 3000 + + def test_get_total_bytes_empty(self, index): + assert index.get_total_bytes() == 0 + + def test_get_entry_count(self, index): + assert index.get_entry_count() == 0 + index.insert_entry((1,), "a.safetensors", 100, 1) + index.insert_entry((2,), "b.safetensors", 200, 1) + assert index.get_entry_count() == 2 + + def test_insert_duplicate_replaces(self, index): + tokens = (1, 2, 3) + index.insert_entry(tokens, "old.safetensors", 100, 3) + index.insert_entry(tokens, "new.safetensors", 200, 3) + result = index.lookup_exact(tokens) + assert result["file_path"] == "new.safetensors" + assert result["memory_bytes"] == 200 + assert index.get_entry_count() == 1 + + def test_touch_updates_access_time(self, index): + import time + + index.insert_entry((1,), "a.safetensors", 100, 1) + time.sleep(0.01) + index.insert_entry((2,), "b.safetensors", 100, 1) + + # Touch the older entry + index.touch((1,)) + + # Now (2,) should be the LRU entry + lru = index.get_lru(limit=1) + assert lru[0]["file_path"] == "b.safetensors" + + +import numpy as np + +from vllm_mlx.ssd_cache import ( + KVCacheSerializer, + ArraysCacheSerializer, + get_serializer_for_layer, + SERIALIZER_SUPPORT_MATRIX, +) + + +class MockMLXArray: + """Minimal mock for MLX array with shape, dtype, and numpy conversion.""" + + def __init__(self, data): + self._data = np.array(data) + self.shape = self._data.shape + self.dtype = type("dtype", (), {"size": self._data.dtype.itemsize})() + + def __array__(self): + return self._data + + +class MockKVCacheLayer: + """Mock KVCache layer with keys, values, offset.""" + + def __init__(self, keys, values, offset): + self.keys = keys + self.values = values + self.offset = offset + + +class MockArraysCacheLayer: + """Mock ArraysCache layer with state list.""" + + def __init__(self, state): + self.state = state + + +class TestLayerSerializer: + """Tests for per-layer serializer interface.""" + + def test_support_matrix_has_required_types(self): + assert "KVCache" in SERIALIZER_SUPPORT_MATRIX + assert "RotatingKVCache" in SERIALIZER_SUPPORT_MATRIX + assert "ArraysCache" in SERIALIZER_SUPPORT_MATRIX + + def test_kv_cache_serializer_round_trip(self, tmp_path): + keys = MockMLXArray(np.random.randn(1, 8, 32, 64).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 8, 32, 64).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=32) + + serializer = KVCacheSerializer() + file_path = str(tmp_path / "layer_0.safetensors") + metadata = serializer.serialize_layer(layer, 0, file_path) + + assert os.path.exists(file_path) + assert metadata["layer_type"] == "KVCache" + assert metadata["offset"] == 32 + + restored = serializer.deserialize_layer(file_path, metadata) + assert restored["offset"] == 32 + np.testing.assert_array_almost_equal( + np.array(restored["keys"]), + np.array(keys), + ) + np.testing.assert_array_almost_equal( + np.array(restored["values"]), + np.array(values), + ) + 
+ def test_arrays_cache_serializer_round_trip(self, tmp_path): + arr0 = MockMLXArray(np.random.randn(1, 64, 128).astype(np.float16)) + arr1 = MockMLXArray(np.random.randn(1, 64, 128).astype(np.float16)) + layer = MockArraysCacheLayer(state=[arr0, arr1]) + + serializer = ArraysCacheSerializer() + file_path = str(tmp_path / "layer_0.safetensors") + metadata = serializer.serialize_layer(layer, 0, file_path) + + assert os.path.exists(file_path) + assert metadata["layer_type"] == "ArraysCache" + assert metadata["num_arrays"] == 2 + + restored = serializer.deserialize_layer(file_path, metadata) + assert len(restored["state"]) == 2 + np.testing.assert_array_almost_equal( + np.array(restored["state"][0]), + np.array(arr0), + ) + + def test_get_serializer_for_kvcache(self): + layer = MockKVCacheLayer(keys=None, values=None, offset=0) + s = get_serializer_for_layer(layer) + assert isinstance(s, KVCacheSerializer) + + def test_get_serializer_for_arrays_cache(self): + layer = MockArraysCacheLayer(state=[]) + s = get_serializer_for_layer(layer) + assert isinstance(s, ArraysCacheSerializer) + + def test_get_serializer_unknown_raises(self): + with pytest.raises(ValueError, match="Unsupported"): + get_serializer_for_layer("not a cache layer") + + +from vllm_mlx.ssd_cache import SSDCacheTier + + +class TestSSDCacheTierCore: + """Tests for SSDCacheTier initialization and directory setup.""" + + def test_init_creates_directory(self, tmp_path): + cache_dir = str(tmp_path / "ssd_tier_test") + config = SSDCacheConfig(cache_dir=cache_dir) + tier = SSDCacheTier(config) + try: + assert os.path.isdir(cache_dir) + assert os.path.exists(os.path.join(cache_dir, "data")) + assert os.path.exists(os.path.join(cache_dir, "index.db")) + finally: + tier.close() + + def test_dir_permissions(self, tmp_path): + cache_dir = str(tmp_path / "ssd_tier_perms") + config = SSDCacheConfig(cache_dir=cache_dir, dir_permissions=0o700) + tier = SSDCacheTier(config) + try: + stat = os.stat(os.path.join(cache_dir, "data")) + # Check owner permissions (at least rwx for owner) + assert stat.st_mode & 0o700 == 0o700 + finally: + tier.close() + + def test_entry_hash_deterministic(self): + tokens = (1, 2, 3, 4, 5) + h1 = SSDCacheTier._entry_hash(tokens) + h2 = SSDCacheTier._entry_hash(tokens) + assert h1 == h2 + assert len(h1) == 64 # SHA-256 hex digest + + def test_entry_hash_different_for_different_tokens(self): + h1 = SSDCacheTier._entry_hash((1, 2, 3)) + h2 = SSDCacheTier._entry_hash((1, 2, 4)) + assert h1 != h2 + + def test_stats_initial(self, tmp_path): + config = SSDCacheConfig(cache_dir=str(tmp_path / "stats_test")) + tier = SSDCacheTier(config) + try: + stats = tier.get_stats() + assert stats["spill_count"] == 0 + assert stats["ssd_hits"] == 0 + finally: + tier.close() + + def test_close_idempotent(self, tmp_path): + config = SSDCacheConfig(cache_dir=str(tmp_path / "close_test")) + tier = SSDCacheTier(config) + tier.close() + tier.close() # Should not raise + + +import time + + +class TestAsyncSpillWriter: + """Tests for the async spill writer thread.""" + + @pytest.fixture + def tier_with_writer(self, tmp_path): + config = SSDCacheConfig(cache_dir=str(tmp_path / "writer_test")) + tier = SSDCacheTier(config) + tier.start_writer() + yield tier + tier.close() + + def test_spill_writes_entry_to_disk(self, tier_with_writer): + tokens = (1, 2, 3, 4, 5) + keys = MockMLXArray(np.random.randn(1, 8, 16, 64).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 8, 16, 64).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, 
values=values, offset=16) + cache = [layer] + + tier_with_writer.enqueue_spill(tokens, cache, memory_bytes=2048) + + # Wait for async write to complete (with timeout) + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier_with_writer._stats.spill_count > 0: + break + time.sleep(0.05) + + assert tier_with_writer._stats.spill_count == 1 + assert tier_with_writer._stats.spill_bytes > 0 + + # Verify entry in index + result = tier_with_writer._index.lookup_exact(tokens) + assert result is not None + assert result["num_tokens"] == 5 + + def test_spill_atomic_write(self, tier_with_writer): + """Verify no partial files exist after spill.""" + tokens = (10, 20, 30) + keys = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=8) + + tier_with_writer.enqueue_spill(tokens, [layer], memory_bytes=1024) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier_with_writer._stats.spill_count > 0: + break + time.sleep(0.05) + + # Check that no .tmp files remain + data_dir = tier_with_writer._data_dir + for root, dirs, files in os.walk(data_dir): + for f in files: + assert not f.endswith(".tmp"), f"Temp file left behind: {f}" + + def test_spill_queue_full_drops(self, tmp_path): + """When queue is full, new spills are dropped (not blocking).""" + config = SSDCacheConfig( + cache_dir=str(tmp_path / "queue_full_test"), + spill_queue_size=2, + ) + tier = SSDCacheTier(config) + # Don't start writer — queue will fill up + keys = MockMLXArray(np.zeros((1, 1, 1, 1), dtype=np.float16)) + values = MockMLXArray(np.zeros((1, 1, 1, 1), dtype=np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=1) + + # Enqueue more than capacity + for i in range(5): + tier.enqueue_spill((i,), [layer], memory_bytes=100) + + # Queue should be at capacity, not have raised + assert tier._spill_queue.qsize() <= 2 + tier.close() + + +from unittest.mock import MagicMock +from vllm_mlx.memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig + + +class MockKVCacheForSpill: + """Mock KVCache with keys, values, offset for memory estimation.""" + + def __init__(self, size_bytes: int): + # Use real numpy arrays so serialization works + self.keys = MockMLXArray(np.zeros((1, 4, 8, 16), dtype=np.float16)) + self.values = MockMLXArray(np.zeros((1, 4, 8, 16), dtype=np.float16)) + self.offset = 8 + + +class TestSpillPath: + """Tests for eviction -> SSD spill integration.""" + + def test_evict_lru_calls_ssd_spill(self, tmp_path): + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1, max_entries=3) + cache = MemoryAwarePrefixCache(model, config) + + ssd_config = SSDCacheConfig(cache_dir=str(tmp_path / "spill_test")) + ssd_tier = SSDCacheTier(ssd_config) + ssd_tier.start_writer() + cache.set_ssd_tier(ssd_tier) + + # Fill cache to capacity (3 entries) + for i in range(3): + kv = [MockKVCacheForSpill(1000)] + cache.store(list(range(i * 10, (i + 1) * 10)), kv) + + assert len(cache) == 3 + + # Store one more to trigger eviction + kv = [MockKVCacheForSpill(1000)] + cache.store(list(range(100, 110)), kv) + + # Wait for spill + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if ssd_tier._stats.spill_count > 0: + break + time.sleep(0.05) + + assert ssd_tier._stats.spill_count >= 1 + ssd_tier.close() + + def test_evict_without_ssd_tier_still_works(self): + """Eviction without SSD tier should work as before 
(discard).""" + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1, max_entries=2) + cache = MemoryAwarePrefixCache(model, config) + + for i in range(5): + kv = [MockKVCacheForSpill(1000)] + cache.store(list(range(i * 10, (i + 1) * 10)), kv) + + # Should not raise, just discard + assert len(cache) <= 2 + + +import asyncio + + +class TestAsyncFetchPath: + """Tests for async SSD fetch with RAM budget reservation.""" + + @pytest.fixture + def populated_tier(self, tmp_path): + """Create an SSD tier with one pre-written entry.""" + config = SSDCacheConfig(cache_dir=str(tmp_path / "fetch_test")) + tier = SSDCacheTier(config) + tier.start_writer() + + tokens = (1, 2, 3, 4, 5) + keys = MockMLXArray(np.random.randn(1, 8, 16, 64).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 8, 16, 64).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=16) + + tier.enqueue_spill(tokens, [layer], memory_bytes=2048) + + # Wait for write + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier._stats.spill_count > 0: + break + time.sleep(0.05) + + yield tier, tokens, keys, values + tier.close() + + def test_lookup_ssd_returns_candidate(self, populated_tier): + tier, tokens, _, _ = populated_tier + candidate = tier.lookup_ssd(tokens) + assert candidate is not None + assert candidate["num_tokens"] == 5 + assert candidate["memory_bytes"] == 2048 + + def test_lookup_ssd_miss(self, populated_tier): + tier, _, _, _ = populated_tier + candidate = tier.lookup_ssd((99, 98, 97)) + assert candidate is None + + def test_async_promote_reserves_budget_then_reads(self, populated_tier): + """RAM budget is reserved BEFORE the SSD read completes.""" + tier, tokens, keys, values = populated_tier + + budget_reserved = [] + budget_released = [] + + def reserve_fn(nbytes): + budget_reserved.append(nbytes) + return True # Budget available + + def release_fn(nbytes): + budget_released.append(nbytes) + + result = asyncio.run(tier.async_promote(tokens, reserve_fn, release_fn)) + + # Budget was reserved + assert len(budget_reserved) == 1 + assert budget_reserved[0] == 2048 + + # Result contains the cache layers + assert result is not None + assert len(result) == 1 # One layer + + # Stats updated + assert tier._stats.ssd_hits == 1 + assert tier._stats.reload_bytes > 0 + + def test_async_promote_budget_denied(self, populated_tier): + """When budget reservation fails, promote returns None.""" + tier, tokens, _, _ = populated_tier + + def reserve_fn(nbytes): + return False # Budget denied + + def release_fn(nbytes): + pass + + result = asyncio.run(tier.async_promote(tokens, reserve_fn, release_fn)) + + assert result is None + assert tier._stats.promotion_failures == 1 + + def test_async_promote_read_failure_releases_budget(self, populated_tier): + """If disk read fails after reservation, budget is released.""" + tier, tokens, _, _ = populated_tier + + budget_reserved = [] + budget_released = [] + + def reserve_fn(nbytes): + budget_reserved.append(nbytes) + return True + + def release_fn(nbytes): + budget_released.append(nbytes) + + # Corrupt the entry on disk + entry_hash = tier._entry_hash(tokens) + entry_dir = os.path.join(tier._data_dir, entry_hash) + manifest_path = os.path.join(entry_dir, "manifest.json") + with open(manifest_path, "w") as f: + f.write("corrupted!") + + result = asyncio.run(tier.async_promote(tokens, reserve_fn, release_fn)) + + assert result is None + # Budget was reserved then released + assert len(budget_reserved) == 1 + assert 
len(budget_released) == 1 + assert budget_released[0] == budget_reserved[0] + + +class TestCapacityManagement: + """Tests for SSD disk capacity management.""" + + @pytest.fixture + def small_tier(self, tmp_path): + """SSD tier with very small capacity for testing eviction.""" + config = SSDCacheConfig( + cache_dir=str(tmp_path / "capacity_test"), + max_size_gb=0.0001, # ~100KB + max_entries=3, + ) + tier = SSDCacheTier(config) + tier.start_writer() + yield tier + tier.close() + + def _write_and_wait(self, tier, tokens, size=1024): + keys = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=8) + tier.enqueue_spill(tokens, [layer], memory_bytes=size) + + deadline = time.monotonic() + 5.0 + initial_count = tier._stats.spill_count + while time.monotonic() < deadline: + if tier._stats.spill_count > initial_count: + break + time.sleep(0.05) + + def test_disk_lru_eviction(self, small_tier): + """When disk capacity is exceeded, oldest entries are evicted.""" + # Fill with 3 entries (at max_entries limit) + for i in range(3): + self._write_and_wait(small_tier, tuple(range(i * 10, (i + 1) * 10))) + + assert small_tier._index.get_entry_count() == 3 + + # Write one more — should trigger disk LRU eviction + self._write_and_wait(small_tier, (100, 101, 102)) + + # Should still be at or below max_entries + assert small_tier._index.get_entry_count() <= 3 + + def test_startup_reconciliation(self, tmp_path): + """On startup, index is reconciled with files on disk.""" + cache_dir = str(tmp_path / "reconcile_test") + config = SSDCacheConfig(cache_dir=cache_dir) + tier = SSDCacheTier(config) + tier.start_writer() + + # Write an entry + tokens = (1, 2, 3) + keys = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=8) + tier.enqueue_spill(tokens, [layer], memory_bytes=1024) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier._stats.spill_count > 0: + break + time.sleep(0.05) + + tier.close() + + # Delete the data file but leave the index entry + entry_hash = SSDCacheTier._entry_hash(tokens) + entry_dir = os.path.join(cache_dir, "data", entry_hash) + import shutil + + shutil.rmtree(entry_dir) + + # Re-open — reconciliation should detect the missing file + tier2 = SSDCacheTier(config) + tier2.reconcile() + + # The orphaned index entry should have been removed + assert tier2._index.lookup_exact(tokens) is None + tier2.close() + + +class TestCLIIntegration: + """Tests for CLI argument parsing of SSD cache flags.""" + + def test_scheduler_config_has_ssd_fields(self): + from vllm_mlx.scheduler import SchedulerConfig + + config = SchedulerConfig() + assert config.ssd_cache_dir is None + assert config.ssd_cache_max_gb == 10.0 + + def test_scheduler_config_custom_ssd(self): + from vllm_mlx.scheduler import SchedulerConfig + + config = SchedulerConfig( + ssd_cache_dir="/tmp/test-ssd", + ssd_cache_max_gb=5.0, + ) + assert config.ssd_cache_dir == "/tmp/test-ssd" + assert config.ssd_cache_max_gb == 5.0 + + +class TestMemoryCacheSSDCheck: + """Tests for MemoryAwarePrefixCache.check_ssd() method.""" + + def test_check_ssd_returns_candidate_on_miss(self, tmp_path): + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1, max_entries=2) + cache = MemoryAwarePrefixCache(model, config) + + ssd_config = 
SSDCacheConfig(cache_dir=str(tmp_path / "check_ssd_test")) + ssd_tier = SSDCacheTier(ssd_config) + ssd_tier.start_writer() + cache.set_ssd_tier(ssd_tier) + + # Write an entry to SSD directly + tokens = (1, 2, 3, 4, 5) + keys = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 4, 8, 32).astype(np.float16)) + layer = MockKVCacheLayer(keys=keys, values=values, offset=8) + ssd_tier.enqueue_spill(tokens, [layer], memory_bytes=1024) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if ssd_tier._stats.spill_count > 0: + break + time.sleep(0.05) + + # RAM fetch should miss + result, remaining = cache.fetch(list(tokens)) + assert result is None + + # But SSD check should find it + candidate = cache.check_ssd(list(tokens)) + assert candidate is not None + assert candidate["num_tokens"] == 5 + + ssd_tier.close() + + def test_check_ssd_returns_none_without_tier(self): + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1) + cache = MemoryAwarePrefixCache(model, config) + + candidate = cache.check_ssd([1, 2, 3]) + assert candidate is None + + def test_check_ssd_returns_none_on_ram_hit(self, tmp_path): + """When RAM has a hit, check_ssd should return None (not needed).""" + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1) + cache = MemoryAwarePrefixCache(model, config) + + kv = [MockKVCacheForSpill(1000)] + cache.store([1, 2, 3], kv) + + # RAM hit exists + result, _ = cache.fetch([1, 2, 3]) + assert result is not None + + # check_ssd should indicate no SSD needed + candidate = cache.check_ssd([1, 2, 3]) + assert candidate is None + + +class TestIntegrationSpillAndFetch: + """End-to-end tests: spill from RAM -> SSD -> promote back.""" + + @pytest.fixture + def cache_with_ssd(self, tmp_path): + """RAM cache + SSD tier, small limits to force eviction.""" + model = MagicMock() + config = MemoryCacheConfig(max_memory_mb=1, max_entries=2) + ram_cache = MemoryAwarePrefixCache(model, config) + + ssd_config = SSDCacheConfig( + cache_dir=str(tmp_path / "integration_test"), + max_size_gb=1.0, + ) + ssd_tier = SSDCacheTier(ssd_config) + ssd_tier.start_writer() + ram_cache.set_ssd_tier(ssd_tier) + + yield ram_cache, ssd_tier + ssd_tier.close() + + def _make_kv_layer(self, offset=16): + keys = MockMLXArray(np.random.randn(1, 8, offset, 64).astype(np.float16)) + values = MockMLXArray(np.random.randn(1, 8, offset, 64).astype(np.float16)) + return MockKVCacheLayer(keys=keys, values=values, offset=offset) + + def test_full_round_trip(self, cache_with_ssd): + """Store -> evict to SSD -> fetch miss in RAM -> find in SSD.""" + ram_cache, ssd_tier = cache_with_ssd + + # Store 2 entries (fills RAM cache) + tokens_a = list(range(0, 10)) + tokens_b = list(range(10, 20)) + ram_cache.store(tokens_a, [self._make_kv_layer()]) + ram_cache.store(tokens_b, [self._make_kv_layer()]) + assert len(ram_cache) == 2 + + # Store a third entry — should evict tokens_a to SSD + tokens_c = list(range(20, 30)) + ram_cache.store(tokens_c, [self._make_kv_layer()]) + + # Wait for SSD spill + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if ssd_tier._stats.spill_count > 0: + break + time.sleep(0.05) + + assert ssd_tier._stats.spill_count >= 1 + + # RAM miss for tokens_a + result, remaining = ram_cache.fetch(tokens_a) + assert result is None + + # SSD should have tokens_a + candidate = ram_cache.check_ssd(tokens_a) + assert candidate is not None + + # Promote from SSD + async def do_promote(): + return await 
ssd_tier.async_promote( + tuple(tokens_a), + lambda n: True, # Budget always available + lambda n: None, + ) + + promoted = asyncio.run(do_promote()) + assert promoted is not None + assert len(promoted) == 1 # One layer + assert ssd_tier._stats.ssd_hits == 1 + + def test_hybrid_cache_round_trip(self, tmp_path): + """ArraysCache layers survive SSD round-trip.""" + config = SSDCacheConfig(cache_dir=str(tmp_path / "hybrid_test")) + tier = SSDCacheTier(config) + tier.start_writer() + + tokens = (1, 2, 3) + arr0 = MockMLXArray(np.random.randn(1, 64, 128).astype(np.float16)) + arr1 = MockMLXArray(np.random.randn(1, 64, 128).astype(np.float16)) + layer = MockArraysCacheLayer(state=[arr0, arr1]) + + tier.enqueue_spill(tokens, [layer], memory_bytes=2048) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier._stats.spill_count > 0: + break + time.sleep(0.05) + + async def do_promote(): + return await tier.async_promote( + tokens, + lambda n: True, + lambda n: None, + ) + + promoted = asyncio.run(do_promote()) + assert promoted is not None + assert len(promoted) == 1 + assert "state" in promoted[0] + assert len(promoted[0]["state"]) == 2 + + # Verify data integrity + np.testing.assert_array_almost_equal( + np.array(promoted[0]["state"][0]), + np.array(arr0), + ) + + tier.close() + + def test_capacity_eviction_end_to_end(self, tmp_path): + """Entries beyond max_entries are evicted from SSD.""" + config = SSDCacheConfig( + cache_dir=str(tmp_path / "cap_e2e"), + max_entries=2, + ) + tier = SSDCacheTier(config) + tier.start_writer() + + for i in range(4): + tokens = tuple(range(i * 10, (i + 1) * 10)) + layer = MockKVCacheLayer( + keys=MockMLXArray(np.zeros((1, 1, 1, 1), dtype=np.float16)), + values=MockMLXArray(np.zeros((1, 1, 1, 1), dtype=np.float16)), + offset=1, + ) + tier.enqueue_spill(tokens, [layer], memory_bytes=100) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if tier._stats.spill_count > i: + break + time.sleep(0.05) + + # Only max_entries (2) should remain + assert tier._index.get_entry_count() <= 2 + + # The latest entries should still be there + assert tier._index.lookup_exact(tuple(range(30, 40))) is not None + + tier.close() + + def test_metrics_accuracy(self, cache_with_ssd): + """Verify all SSD metrics are updated correctly.""" + ram_cache, ssd_tier = cache_with_ssd + + # Fill RAM and force one spill + ram_cache.store(list(range(10)), [self._make_kv_layer()]) + ram_cache.store(list(range(10, 20)), [self._make_kv_layer()]) + ram_cache.store(list(range(20, 30)), [self._make_kv_layer()]) + + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if ssd_tier._stats.spill_count > 0: + break + time.sleep(0.05) + + stats = ssd_tier.get_stats() + assert stats["spill_count"] >= 1 + assert stats["spill_bytes"] > 0 + + # Promote the spilled entry + evicted_tokens = tuple(range(10)) # First stored, first evicted + candidate = ssd_tier.lookup_ssd(evicted_tokens) + if candidate is not None: + + async def promote(): + return await ssd_tier.async_promote( + evicted_tokens, lambda n: True, lambda n: None + ) + + asyncio.run(promote()) + stats = ssd_tier.get_stats() + assert stats["ssd_hits"] >= 1 + assert stats["reload_bytes"] > 0 + assert stats["avg_reload_latency_ms"] > 0 diff --git a/tests/test_streaming_fence_stripper.py b/tests/test_streaming_fence_stripper.py new file mode 100644 index 000000000..6fe19a84f --- /dev/null +++ b/tests/test_streaming_fence_stripper.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 
+"""Tests for ``StreamingJsonFenceStripper`` — strips markdown code fences +from streamed ``response_format`` content. +""" + +from vllm_mlx.api.tool_calling import StreamingJsonFenceStripper + + +def _stream(chunks): + """Feed ``chunks`` through a fresh stripper and collect emitted text.""" + s = StreamingJsonFenceStripper() + out = [] + for c in chunks: + out.append(s.feed(c)) + out.append(s.finalize()) + return "".join(out) + + +class TestStreamingFenceStripperNoFence: + def test_single_delta_no_fence(self): + assert _stream(['{"answer": true}']) == '{"answer": true}' + + def test_multiple_deltas_no_fence(self): + assert _stream(["{", '"a"', ":", "1", "}"]) == '{"a":1}' + + def test_empty_input(self): + assert _stream([]) == "" + + def test_only_empty_deltas(self): + assert _stream(["", "", ""]) == "" + + def test_content_with_inner_backticks_preserved(self): + # Backticks inside a string value must not be interpreted as a fence. + text = '{"code": "x `y` z"}' + assert _stream([text]) == text + + +class TestStreamingFenceStripperLeadingFence: + def test_leading_backticks_json_newline(self): + assert _stream(['```json\n{"a": 1}']) == '{"a": 1}' + + def test_leading_backticks_json(self): + assert _stream(['```json{"a": 1}']) == '{"a": 1}' + + def test_leading_backticks_newline(self): + assert _stream(['```\n{"a": 1}']) == '{"a": 1}' + + def test_leading_bare_backticks(self): + assert _stream(['```{"a": 1}']) == '{"a": 1}' + + def test_leading_whitespace_then_fence(self): + assert _stream([' \n ```json\n{"a": 1}']) == '{"a": 1}' + + def test_fence_split_across_deltas_at_backticks(self): + # Classic case: tokenizer splits ``` from ``json``. + assert _stream(["```", 'json\n{"a":1}']) == '{"a":1}' + + def test_fence_split_inside_json_word(self): + assert _stream(["```jso", 'n\n{"a":1}']) == '{"a":1}' + + def test_fence_split_after_each_char(self): + assert ( + _stream(["`", "`", "`", "j", "s", "o", "n", "\n", '{"a":1}']) == '{"a":1}' + ) + + def test_no_fence_but_starts_with_backtick_char(self): + # A single backtick should NOT be treated as a fence prefix forever — + # as soon as a non-fence-prefix char arrives, we emit. + assert _stream(["`hello`"]) == "`hello`" + + +class TestStreamingFenceStripperTrailingFence: + def test_trailing_backticks(self): + assert _stream(['{"a": 1}```']) == '{"a": 1}' + + def test_trailing_newline_backticks(self): + assert _stream(['{"a": 1}\n```']) == '{"a": 1}' + + def test_trailing_newline_backticks_newline(self): + assert _stream(['{"a": 1}\n```\n']) == '{"a": 1}' + + def test_trailing_fence_split_across_deltas(self): + assert _stream(['{"a": 1}', "\n`", "``\n"]) == '{"a": 1}' + + def test_full_wrap_single_delta(self): + assert _stream(['```json\n{"a": 1}\n```']) == '{"a": 1}' + + def test_full_wrap_split_deltas(self): + assert _stream(["```json\n", '{"a": ', "1}", "\n", "```"]) == '{"a": 1}' + + +class TestStreamingFenceStripperOrdering: + def test_no_fence_emits_incrementally(self): + s = StreamingJsonFenceStripper() + # Long-enough chunk: should emit everything except last 5 chars now. + first = s.feed('{"result": "ok", "count": 42}') + # 29 chars in; held back = 5, so 24 emitted. + assert first == '{"result": "ok", "count"' + last = s.finalize() + assert last == ": 42}" + assert first + last == '{"result": "ok", "count": 42}' + + def test_leading_fence_emits_past_holdback(self): + s = StreamingJsonFenceStripper() + out1 = s.feed('```json\n{"a": 1, "b": 2, "c": 3}') + # After stripping fence: '{"a": 1, "b": 2, "c": 3}' (24 chars). 
+ # Emit 24 - 5 = 19 chars now. + assert out1 == '{"a": 1, "b": 2, "c' + out2 = s.finalize() + assert out2 == '": 3}' + assert out1 + out2 == '{"a": 1, "b": 2, "c": 3}' + + def test_leading_fence_only_no_content(self): + # Stream ends with just the fence and no JSON — should emit nothing. + assert _stream(["```json\n"]) == "" + + def test_leading_partial_fence_only(self): + # Stream ends with just a partial fence — finalize should strip it. + assert _stream(["```js"]) == "" + + +class TestStreamingFenceStripperEdgeCases: + def test_content_ending_with_backtick_in_string(self): + # Single trailing backtick at end of content — held back then emitted. + text = '{"note": "end `"}' + assert _stream([text]) == text + + def test_fence_with_extra_whitespace_around_closing(self): + # Extra whitespace after closing fence still strips. + assert _stream(['{"a": 1}\n``` ']) == '{"a": 1}' + + def test_array_output(self): + assert _stream(["```json\n[1, 2, 3]\n```"]) == "[1, 2, 3]" + + def test_empty_object(self): + assert _stream(["```json\n{}\n```"]) == "{}" diff --git a/tests/test_structured_output.py b/tests/test_structured_output.py index 75eeee8c5..3d12ab87c 100644 --- a/tests/test_structured_output.py +++ b/tests/test_structured_output.py @@ -126,6 +126,172 @@ def test_nested_json(self): result = extract_json_from_text(text) assert result == {"outer": {"inner": {"deep": "value"}}} + def test_chatty_preamble_with_json(self): + """ + Test the classic Minimax / chain-of-thought failure mode: + the model explains itself before emitting JSON. + """ + text = ( + "Let me format this as JSON:\n" '```json\n{"name": "John", "age": 25}\n```' + ) + result = extract_json_from_text(text) + assert result == {"name": "John", "age": 25} + + def test_unterminated_markdown_fence(self): + """ + Test truncation: model started ```json fence but ran out of + tokens before closing it. Previously returned None. + """ + text = 'Let me format this as JSON:\n```json\n{"name": "John", "age": 25}' + result = extract_json_from_text(text) + assert result == {"name": "John", "age": 25} + + def test_truncated_json_unclosed_object(self): + """Test repair of JSON truncated mid-object (max_tokens hit).""" + text = '{"name": "John", "age": 25, "city": "Prague"' + result = extract_json_from_text(text) + assert result == {"name": "John", "age": 25, "city": "Prague"} + + def test_truncated_json_unclosed_string(self): + """Test repair of JSON truncated mid-string.""" + text = '{"name": "John", "city": "Pra' + result = extract_json_from_text(text) + assert result == {"name": "John", "city": "Pra"} + + def test_truncated_json_unclosed_nested(self): + """Test repair of JSON truncated deep inside nested structures.""" + text = '{"user": {"name": "John", "items": [1, 2, 3' + result = extract_json_from_text(text) + assert result == {"user": {"name": "John", "items": [1, 2, 3]}} + + def test_truncated_json_dangling_key(self): + """Test repair when truncation leaves a dangling key.""" + text = '{"name": "John", "age":' + result = extract_json_from_text(text) + # Dangling "age": should get repaired to null or dropped. + assert result is not None + assert result.get("name") == "John" + + def test_balanced_scan_prefers_valid_json(self): + """ + Test that balanced-brace scanning picks the first balanced JSON, + not a greedy region between first { and last }. + """ + text = 'First: {"a": 1}. 
Garbage: {broken}' + result = extract_json_from_text(text) + assert result == {"a": 1} + + def test_json_with_escaped_braces_in_string(self): + """Test balanced scanner respects strings containing braces.""" + text = '{"template": "Hello {user}!", "count": 5}' + result = extract_json_from_text(text) + assert result == {"template": "Hello {user}!", "count": 5} + + def test_json_with_escaped_quotes(self): + """Test balanced scanner respects escaped quotes inside strings.""" + text = '{"text": "He said \\"hi\\"", "ok": true}' + result = extract_json_from_text(text) + assert result == {"text": 'He said "hi"', "ok": True} + + +class TestRawJsonToolCallHijackPrevention: + """ + Regression tests for the MiniMax-M2 bug where ``response_format`` with a + schema that contained ``"name"`` (e.g. person name) was hijacked into a + fake ``function.name`` tool call by the raw-JSON tool-call fallback. + """ + + def test_response_format_not_hijacked_no_tools(self): + """JSON data with 'name' field must not become a tool call.""" + from vllm_mlx.api.tool_calling import parse_tool_calls + + # Simulate Minimax emitting user-schema JSON after response_format. + text = '{"name": "John", "age": 25}' + # Request has NO tools — any parse_tool_calls output is a hijack. + cleaned, tool_calls = parse_tool_calls(text, request={"tools": None}) + assert tool_calls is None, f"Hijacked JSON into tool_calls: {tool_calls}" + + def test_response_format_with_response_format_no_tools(self): + """ + Even when the request has ``response_format`` set and no tools, + a JSON object with 'name' must not become a tool call. + """ + from vllm_mlx.api.tool_calling import parse_tool_calls + + text = '{"name": "Alice", "age": 30, "city": "Prague"}' + cleaned, tool_calls = parse_tool_calls( + text, + request={ + "tools": None, + "response_format": {"type": "json_object"}, + }, + ) + assert tool_calls is None + + def test_genuine_tool_call_still_detected(self): + """ + Regression guard: a genuine tool call (``name`` + ``arguments``) + MUST still be detected when the caller passed ``tools``. + """ + from vllm_mlx.api.tool_calling import parse_tool_calls + + text = '{"name": "get_weather", "arguments": {"city": "Prague"}}' + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + } + ] + cleaned, tool_calls = parse_tool_calls(text, request={"tools": tools}) + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "get_weather" + + def test_no_name_no_arguments_not_tool_call(self): + """Plain JSON without 'name' stays as content.""" + from vllm_mlx.api.tool_calling import parse_tool_calls + + text = '{"result": 42, "status": "ok"}' + cleaned, tool_calls = parse_tool_calls( + text, request={"tools": [{"type": "function"}]} + ) + assert tool_calls is None + + def test_name_only_not_tool_call(self): + """ + ``{"name": "x"}`` without ``"arguments"`` must NOT be a tool call. + This is the specific shape that hijacked MiniMax-M2 output. + """ + from vllm_mlx.api.tool_calling import _parse_raw_json_tool_calls + + assert _parse_raw_json_tool_calls('{"name": "John"}') is None + assert _parse_raw_json_tool_calls('{"name": "John", "age": 25}') is None + + def test_looks_like_tool_call_helper(self): + """Direct coverage for the _looks_like_tool_call heuristic.""" + from vllm_mlx.api.tool_calling import _looks_like_tool_call + + # Genuine tool calls. 
+ assert _looks_like_tool_call({"name": "f", "arguments": {}}) + assert _looks_like_tool_call({"name": "f", "arguments": '{"a":1}'}) + assert _looks_like_tool_call( + {"name": "get_weather", "arguments": {"city": "Prague"}} + ) + # Not tool calls. + assert not _looks_like_tool_call({"name": "John", "age": 25}) + assert not _looks_like_tool_call({"name": "", "arguments": {}}) + assert not _looks_like_tool_call({"arguments": {}}) + assert not _looks_like_tool_call({"name": 123, "arguments": {}}) + assert not _looks_like_tool_call({"name": "f", "arguments": 42}) + assert not _looks_like_tool_call("not a dict") + assert not _looks_like_tool_call(None) + class TestParseJsonOutput: """Tests for parse_json_output function.""" @@ -253,6 +419,17 @@ def test_json_object(self): assert "valid JSON" in result assert "only" in result.lower() + def test_json_object_has_strict_rules(self): + """ + Test that json_object prompt contains explicit no-markdown / + no-preamble rules (added to prevent chatty model failures). + """ + result = build_json_system_prompt({"type": "json_object"}) + assert result is not None + assert "STRICT" in result + assert "markdown" in result.lower() + assert "preamble" in result.lower() or "Preamble" in result + def test_json_schema(self): """Test prompt for json_schema mode.""" response_format = { @@ -272,6 +449,20 @@ def test_json_schema(self): assert "A person object" in result assert "JSON Schema" in result + def test_json_schema_has_strict_rules(self): + """Test json_schema prompt also carries the strict output rules.""" + response_format = { + "type": "json_schema", + "json_schema": { + "name": "person", + "schema": {"type": "object"}, + }, + } + result = build_json_system_prompt(response_format) + assert result is not None + assert "STRICT" in result + assert "markdown" in result.lower() + def test_json_schema_model(self): """Test prompt with ResponseFormat model.""" json_schema = ResponseFormatJsonSchema( diff --git a/tests/test_tool_choice_forced.py b/tests/test_tool_choice_forced.py new file mode 100644 index 000000000..f97c10e67 --- /dev/null +++ b/tests/test_tool_choice_forced.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for forced tool_choice support. 
+ +Covers: +- tool_choice={"type":"function","function":{"name":"X"}} (forced specific tool) +- tool_choice="required" (must call some tool) +- QwenToolParser.SUPPORTS_NATIVE_TOOL_FORMAT +- QwenToolParser empty wrapper cleanup +""" + +import json + +import pytest + +from vllm_mlx.server import ( + _apply_forced_tool_choice, + _get_forced_tool_name, + _tool_name, +) +from vllm_mlx.tool_parsers import QwenToolParser + +# --------------------------------------------------------------------------- +# Helper fixtures +# --------------------------------------------------------------------------- + +WEATHER_TOOL = { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + }, +} + +CALC_TOOL = { + "type": "function", + "function": { + "name": "calculate", + "description": "Calculate math", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + }, +} + + +# --------------------------------------------------------------------------- +# _get_forced_tool_name +# --------------------------------------------------------------------------- + + +class TestGetForcedToolName: + """Test extraction of forced tool name from tool_choice.""" + + def test_dict_with_function_name(self): + tc = {"type": "function", "function": {"name": "calculate"}} + assert _get_forced_tool_name(tc) == "calculate" + + def test_dict_missing_function_key(self): + tc = {"type": "function"} + assert _get_forced_tool_name(tc) is None + + def test_dict_wrong_type(self): + tc = {"type": "tool", "function": {"name": "foo"}} + assert _get_forced_tool_name(tc) is None + + def test_string_auto(self): + assert _get_forced_tool_name("auto") is None + + def test_string_none(self): + assert _get_forced_tool_name("none") is None + + def test_string_required(self): + assert _get_forced_tool_name("required") is None + + def test_none_value(self): + assert _get_forced_tool_name(None) is None + + def test_empty_dict(self): + assert _get_forced_tool_name({}) is None + + +# --------------------------------------------------------------------------- +# _tool_name +# --------------------------------------------------------------------------- + + +class TestToolName: + """Test _tool_name helper.""" + + def test_extracts_name(self): + assert _tool_name(WEATHER_TOOL) == "get_weather" + + def test_no_function_key(self): + assert _tool_name({"type": "function"}) is None + + def test_function_not_dict(self): + assert _tool_name({"function": "not_a_dict"}) is None + + +# --------------------------------------------------------------------------- +# _apply_forced_tool_choice +# --------------------------------------------------------------------------- + + +class TestApplyForcedToolChoice: + """Test _apply_forced_tool_choice for forced and required modes.""" + + def _make_messages(self): + return [{"role": "user", "content": "Hello"}] + + def test_forced_specific_tool_filters(self): + """Forced tool_choice should filter tools to only the named one.""" + tools = [WEATHER_TOOL, CALC_TOOL] + msgs = self._make_messages() + tc = {"type": "function", "function": {"name": "calculate"}} + + new_tools, new_msgs = _apply_forced_tool_choice(tc, tools, msgs) + + assert len(new_tools) == 1 + assert _tool_name(new_tools[0]) == "calculate" + + def test_forced_specific_tool_injects_instruction(self): + """Forced tool_choice should inject a system 
instruction.""" + tools = [WEATHER_TOOL, CALC_TOOL] + msgs = self._make_messages() + tc = {"type": "function", "function": {"name": "calculate"}} + + _, new_msgs = _apply_forced_tool_choice(tc, tools, msgs) + + # Should have a system message prepended or appended + system_msgs = [m for m in new_msgs if m.get("role") == "system"] + assert len(system_msgs) >= 1 + assert "calculate" in system_msgs[0]["content"] + + def test_forced_disables_thinking(self): + """Forced tool_choice should set enable_thinking=False in chat_kwargs.""" + tools = [WEATHER_TOOL] + msgs = self._make_messages() + tc = {"type": "function", "function": {"name": "get_weather"}} + kwargs = {} + + _apply_forced_tool_choice(tc, tools, msgs, kwargs) + + assert kwargs.get("enable_thinking") is False + + def test_forced_unknown_tool_raises(self): + """Forced tool_choice with unknown function name should raise ValueError.""" + tools = [WEATHER_TOOL, CALC_TOOL] + msgs = self._make_messages() + tc = {"type": "function", "function": {"name": "nonexistent"}} + + with pytest.raises(ValueError, match="not found in tools"): + _apply_forced_tool_choice(tc, tools, msgs) + + def test_required_injects_instruction(self): + """tool_choice='required' should inject a system instruction.""" + tools = [WEATHER_TOOL, CALC_TOOL] + msgs = self._make_messages() + + new_tools, new_msgs = _apply_forced_tool_choice("required", tools, msgs) + + # All tools kept + assert len(new_tools) == 2 + # System instruction injected + system_msgs = [m for m in new_msgs if m.get("role") == "system"] + assert len(system_msgs) >= 1 + assert "MUST" in system_msgs[0]["content"] + + def test_required_does_not_disable_thinking(self): + """tool_choice='required' should NOT disable thinking.""" + tools = [WEATHER_TOOL] + msgs = self._make_messages() + kwargs = {} + + _apply_forced_tool_choice("required", tools, msgs, kwargs) + + assert "enable_thinking" not in kwargs + + def test_auto_is_noop(self): + """tool_choice='auto' should not modify anything.""" + tools = [WEATHER_TOOL, CALC_TOOL] + msgs = self._make_messages() + + new_tools, new_msgs = _apply_forced_tool_choice("auto", tools, msgs) + + assert new_tools is tools # Same object, not modified + assert new_msgs is msgs + + def test_none_tools_is_noop(self): + """Empty tools list should be a no-op.""" + msgs = self._make_messages() + + new_tools, new_msgs = _apply_forced_tool_choice("required", None, msgs) + + assert new_tools is None + assert new_msgs is msgs + + def test_forced_appends_to_existing_system(self): + """If system message already exists, instruction should be appended.""" + tools = [WEATHER_TOOL] + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + tc = {"type": "function", "function": {"name": "get_weather"}} + + _, new_msgs = _apply_forced_tool_choice(tc, tools, msgs) + + system_content = new_msgs[0]["content"] + assert system_content.startswith("You are helpful.") + assert "get_weather" in system_content + + def test_does_not_mutate_original_messages(self): + """Original messages list should not be mutated.""" + tools = [WEATHER_TOOL] + msgs = [{"role": "user", "content": "Hello"}] + original_len = len(msgs) + tc = {"type": "function", "function": {"name": "get_weather"}} + + _, new_msgs = _apply_forced_tool_choice(tc, tools, msgs) + + assert len(msgs) == original_len # Original unchanged + assert len(new_msgs) != original_len # New one has system msg + + +# --------------------------------------------------------------------------- +# 
QwenToolParser — SUPPORTS_NATIVE_TOOL_FORMAT +# --------------------------------------------------------------------------- + + +class TestQwenNativeToolFormat: + """Test QwenToolParser native tool format support.""" + + def test_supports_native_format_class_attribute(self): + assert QwenToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is True + + def test_supports_native_format_method(self): + assert QwenToolParser.supports_native_format() is True + + +# --------------------------------------------------------------------------- +# QwenToolParser — empty cleanup +# --------------------------------------------------------------------------- + + +class TestQwenToolCallCleanup: + """Test that empty wrappers are cleaned from content.""" + + def _parser(self): + return QwenToolParser(tokenizer=None) + + def test_function_style_cleans_wrapper(self): + """... should leave no content.""" + text = ( + "\n" + "Prague" + "\n" + ) + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + # Content should be None or empty after cleanup + assert not result.content + + def test_multiple_function_calls_clean_wrappers(self): + """Multiple wrappers should all be cleaned.""" + text = ( + "\n" + "Paris" + "\n\n" + "\n" + "London" + "\n" + ) + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert len(result.tool_calls) == 2 + assert not result.content + + def test_text_before_tool_call_preserved(self): + """Text before the tool call block should be preserved.""" + text = ( + "Let me check the weather.\n" + "\n" + "Tokyo" + "\n" + ) + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert result.content == "Let me check the weather." 
+ + def test_xml_style_no_double_cleanup(self): + """XML-style tool calls already clean their own tags.""" + text = '{"name": "get_weather", "arguments": {"location": "Berlin"}}' + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert not result.content + + def test_function_with_json_args_cleans_wrapper(self): + """{"key":"val"} wrapped in should clean.""" + text = ( + "\n" + '{"expression": "2+2"}\n' + "" + ) + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert result.tool_calls[0]["name"] == "calculate" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["expression"] == "2+2" + assert not result.content + + def test_think_tags_stripped_before_cleanup(self): + """ tags should be stripped, then cleaned.""" + text = ( + "Let me think about this...\n" + "\n" + "Rome" + "\n" + ) + result = self._parser().extract_tool_calls(text) + assert result.tools_called is True + assert result.tool_calls[0]["name"] == "get_weather" + assert not result.content diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index 4f3c287d1..6e0211a35 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -555,6 +555,16 @@ def test_detects_qwen_bracket(self, parser): assert result.tools_called assert result.tool_calls[0]["name"] == "add" + def test_detects_bare_bracket(self, parser): + """Test auto detection of bare bracket format.""" + text = '[read({"file_path": "/tmp/test.py"})]' + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert result.tool_calls[0]["name"] == "read" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["file_path"] == "/tmp/test.py" + def test_detects_llama(self, parser): """Test auto detection of Llama format.""" text = '{"x": 2}' @@ -651,6 +661,39 @@ def test_tool_call_id_uniqueness(self): assert len(ids) == len(set(ids)), "Tool call IDs should be unique" +class TestBareBracketStreaming: + """Test streaming for bare bracket tool calls.""" + + def test_auto_streaming_bare_bracket(self): + """Auto parser should emit structured tool calls for bare bracket streaming.""" + parser = AutoToolParser() + + chunks = [ + "[read(", + '{"file_path": "/tmp/test.py"}', + ")]", + ] + accumulated = "" + tool_calls_found = False + + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + tool_calls_found = True + assert r["tool_calls"][0]["function"]["name"] == "read" + args = json.loads(r["tool_calls"][0]["function"]["arguments"]) + assert args["file_path"] == "/tmp/test.py" + break + + assert tool_calls_found + + class TestStreamingParsing: """Test streaming tool call parsing.""" @@ -1243,6 +1286,29 @@ def test_streaming_function_format_complete(self, parser): break assert tool_calls_found + def test_streaming_bracket_call_closing_marker_split(self, parser): + """Qwen bracket calls should complete when ')' and ']' split chunks.""" + chunks = [ + '[Calling tool: add({"a": 1, "b": 2})', + "]", + ] + + accumulated = "" + emitted = None + for chunk in chunks: + previous = accumulated + accumulated += chunk + emitted = parser.extract_tool_calls_streaming( + previous_text=previous, + current_text=accumulated, + delta_text=chunk, + ) + + assert emitted is not None + assert "tool_calls" in emitted + assert emitted["tool_calls"][0]["function"]["name"] == "add" + assert 
emitted["tool_calls"][0]["function"]["arguments"] == ('{"a": 1, "b": 2}') + def test_streaming_partial_marker_buffered(self, parser): """Test that partial ' ChatCompletionRequest: stop=request.stop_sequences, tools=tools, tool_choice=tool_choice, + # Forward response_format as-is; ChatCompletionRequest coerces raw + # dicts into the strict ResponseFormat model via pydantic. + response_format=request.response_format, + chat_template_kwargs=request.chat_template_kwargs, ) diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py index e8854a5e6..c09bab181 100644 --- a/vllm_mlx/api/anthropic_models.py +++ b/vllm_mlx/api/anthropic_models.py @@ -56,7 +56,7 @@ class AnthropicRequest(BaseModel): model: str messages: list[AnthropicMessage] system: str | list[dict] | None = None - max_tokens: int # Required in Anthropic API + max_tokens: int = Field(gt=0) # Required in Anthropic API temperature: float | None = None top_p: float | None = None stream: bool = False @@ -65,6 +65,12 @@ class AnthropicRequest(BaseModel): tool_choice: dict | None = None metadata: dict | None = None top_k: int | None = None + # OpenAI-compatible extension (not in the official Anthropic spec, but + # clients commonly forward it via extra_body or top-level for structured + # output / constrained decoding on this endpoint). + response_format: dict | None = None + # OpenAI-compatible extension for tokenizer chat template kwargs. + chat_template_kwargs: dict[str, Any] | None = None # ============================================================================= diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index 0ebe616eb..f54f7cff7 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -11,8 +11,9 @@ import time import uuid +from typing import Any -from pydantic import AliasChoices, BaseModel, Field +from pydantic import AliasChoices, BaseModel, Field, model_serializer # ============================================================================= # Content Types (for multimodal messages) @@ -162,7 +163,7 @@ class ChatCompletionRequest(BaseModel): min_p: float | None = None presence_penalty: float | None = None repetition_penalty: float | None = None - max_tokens: int | None = None + max_tokens: int | None = Field(default=None, gt=0) stream: bool = False stream_options: StreamOptions | None = ( None # Streaming options (include_usage, etc.) @@ -173,6 +174,8 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict | None = None # "auto", "none", or specific tool # Structured output response_format: ResponseFormat | dict | None = None + # Extra kwargs forwarded to tokenizer.apply_chat_template + chat_template_kwargs: dict[str, Any] | None = None # MLLM-specific parameters video_fps: float | None = None video_max_frames: int | None = None @@ -203,6 +206,20 @@ class AssistantMessage(BaseModel): def reasoning(self) -> str | None: return self.reasoning_content + @model_serializer + def _serialize(self) -> dict: + """Serialize with OpenAI-compatible schema. + + - ``tool_calls`` and ``reasoning_content`` are omitted when None. + - ``content`` is always included (even as null) per OpenAI spec. 
+ """ + d: dict = {"role": self.role, "content": self.content} + if self.reasoning_content is not None: + d["reasoning_content"] = self.reasoning_content + if self.tool_calls is not None: + d["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] + return d + class ChatCompletionChoice(BaseModel): """A single choice in chat completion response.""" @@ -247,7 +264,7 @@ class CompletionRequest(BaseModel): min_p: float | None = None presence_penalty: float | None = None repetition_penalty: float | None = None - max_tokens: int | None = None + max_tokens: int | None = Field(default=None, gt=0) stream: bool = False stop: list[str] | None = None # Sampling penalties @@ -431,6 +448,43 @@ class EmbeddingResponse(BaseModel): usage: EmbeddingUsage = Field(default_factory=EmbeddingUsage) +# ============================================================================= +# Reranking +# ============================================================================= + + +class RerankRequest(BaseModel): + """Request for reranking documents against a query (Jina/Cohere convention).""" + + model: str + query: str + documents: list[str | dict] + top_n: int | None = None + return_documents: bool = True + + +class RerankResult(BaseModel): + """A single reranked document result.""" + + index: int + relevance_score: float + document: dict | None = None + + +class RerankUsage(BaseModel): + """Token usage for rerank requests.""" + + total_tokens: int = 0 + + +class RerankResponse(BaseModel): + """Response for reranking endpoint (Jina/Cohere convention).""" + + model: str + results: list[RerankResult] + usage: RerankUsage = Field(default_factory=RerankUsage) + + # ============================================================================= # Streaming (for SSE responses) # ============================================================================= @@ -451,6 +505,24 @@ class ChatCompletionChunkDelta(BaseModel): def reasoning(self) -> str | None: return self.reasoning_content + @model_serializer + def _serialize(self) -> dict: + """Serialize delta with only non-None fields. + + Per OpenAI streaming spec, delta objects only include fields that + carry new content. + """ + d: dict = {} + if self.role is not None: + d["role"] = self.role + if self.content is not None: + d["content"] = self.content + if self.reasoning_content is not None: + d["reasoning_content"] = self.reasoning_content + if self.tool_calls is not None: + d["tool_calls"] = self.tool_calls + return d + class ChatCompletionChunkChoice(BaseModel): """A single choice in a streaming chunk.""" diff --git a/vllm_mlx/api/responses_models.py b/vllm_mlx/api/responses_models.py new file mode 100644 index 000000000..5c1e3bcd3 --- /dev/null +++ b/vllm_mlx/api/responses_models.py @@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Pydantic models for the OpenAI-compatible Responses API. + +This intentionally implements the subset needed for local coding-agent +workflows: text messages, function tools, function call outputs, and SSE +streaming events. The object and event shapes follow the conventions used by +OpenAI's gpt-oss reference server and llama.cpp's OpenAI-compatible server. 
+""" + +import time +import uuid +from typing import Literal +from typing import Any + +from pydantic import BaseModel, Field, computed_field + + +class ResponseTextFormat(BaseModel): + """Output text format configuration.""" + + type: Literal["text", "json_object"] = "text" + + +class ResponseTextConfig(BaseModel): + """Text output configuration.""" + + format: ResponseTextFormat = Field(default_factory=ResponseTextFormat) + + +class ResponseReasoningConfig(BaseModel): + """Reasoning configuration.""" + + effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] | None = None + + +class ResponseTextContentPart(BaseModel): + """A text content part for message items.""" + + type: Literal["text", "input_text", "output_text"] = "output_text" + text: str + annotations: list[dict] = Field(default_factory=list) + logprobs: list[dict] = Field(default_factory=list) + + +class ResponseReasoningTextPart(BaseModel): + """A reasoning text content part.""" + + type: Literal["reasoning_text"] = "reasoning_text" + text: str + + +class ResponseReasoningSummaryTextPart(BaseModel): + """A reasoning summary item.""" + + type: Literal["summary_text"] = "summary_text" + text: str + + +class ResponseMessageItem(BaseModel): + """A Responses API message item.""" + + id: str | None = None + type: Literal["message"] = "message" + role: Literal["system", "user", "assistant", "developer"] = "assistant" + content: str | list[ResponseTextContentPart] = Field(default_factory=list) + status: Literal["in_progress", "completed", "incomplete"] | None = "completed" + + +class ResponseReasoningItem(BaseModel): + """A reasoning output item.""" + + id: str | None = None + type: Literal["reasoning"] = "reasoning" + summary: list[ResponseReasoningSummaryTextPart] = Field(default_factory=list) + content: list[ResponseReasoningTextPart] = Field(default_factory=list) + status: Literal["in_progress", "completed", "incomplete"] | None = "completed" + + +class ResponseFunctionCallItem(BaseModel): + """A function call output item.""" + + id: str | None = None + type: Literal["function_call"] = "function_call" + call_id: str + name: str + arguments: str + status: Literal["in_progress", "completed", "incomplete"] = "completed" + + +class ResponseFunctionCallOutputItem(BaseModel): + """A tool result item passed back into a later request.""" + + type: Literal["function_call_output"] = "function_call_output" + call_id: str + output: str + + +class ResponseFunctionTool(BaseModel): + """A function tool definition.""" + + type: Literal["function"] = "function" + name: str + description: str | None = "" + parameters: dict = Field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + strict: bool = False + + +class ResponsesInputTokenDetails(BaseModel): + """Input token breakdown.""" + + cached_tokens: int = 0 + + +class ResponsesOutputTokenDetails(BaseModel): + """Output token breakdown.""" + + reasoning_tokens: int = 0 + + +class ResponsesUsage(BaseModel): + """Responses API token usage.""" + + input_tokens: int + output_tokens: int + total_tokens: int + input_tokens_details: ResponsesInputTokenDetails = Field( + default_factory=ResponsesInputTokenDetails + ) + output_tokens_details: ResponsesOutputTokenDetails = Field( + default_factory=ResponsesOutputTokenDetails + ) + + +class ResponseError(BaseModel): + """Error payload.""" + + code: str + message: str + + +class ResponseIncompleteDetails(BaseModel): + """Incomplete response details.""" + + reason: str + + +class ResponsesRequest(BaseModel): + """Request payload 
for /v1/responses.""" + + model: str + input: ( + str + | list[ + ResponseMessageItem + | ResponseReasoningItem + | ResponseFunctionCallItem + | ResponseFunctionCallOutputItem + | dict + ] + ) + instructions: str | None = None + max_output_tokens: int | None = None + stream: bool = False + tools: list[ResponseFunctionTool | dict] = Field(default_factory=list) + tool_choice: str | dict | None = "auto" + parallel_tool_calls: bool = True + previous_response_id: str | None = None + temperature: float | None = None + top_p: float | None = None + chat_template_kwargs: dict[str, Any] | None = None + metadata: dict = Field(default_factory=dict) + text: ResponseTextConfig = Field(default_factory=ResponseTextConfig) + reasoning: ResponseReasoningConfig | None = None + store: bool = True + truncation: str = "disabled" + user: str | None = None + + +class ResponseObject(BaseModel): + """Response object for /v1/responses.""" + + id: str = Field(default_factory=lambda: f"resp_{uuid.uuid4().hex}") + object: Literal["response"] = "response" + created_at: int = Field(default_factory=lambda: int(time.time())) + status: Literal["completed", "failed", "incomplete", "in_progress"] = "completed" + background: bool = False + error: ResponseError | None = None + incomplete_details: ResponseIncompleteDetails | None = None + instructions: str | None = None + max_output_tokens: int | None = None + max_tool_calls: int | None = None + metadata: dict = Field(default_factory=dict) + model: str + output: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ] = Field(default_factory=list) + parallel_tool_calls: bool = True + previous_response_id: str | None = None + text: ResponseTextConfig = Field(default_factory=ResponseTextConfig) + tool_choice: str | dict | None = "auto" + tools: list[ResponseFunctionTool | dict] = Field(default_factory=list) + top_p: float = 1.0 + temperature: float | None = None + truncation: str = "disabled" + usage: ResponsesUsage | None = None + user: str | None = None + store: bool = True + + @computed_field + @property + def output_text(self) -> str: + """Concatenate assistant text content into the convenience field.""" + text_parts: list[str] = [] + for item in self.output: + if not isinstance(item, ResponseMessageItem): + continue + if isinstance(item.content, str): + text_parts.append(item.content) + continue + for part in item.content: + if part.type == "output_text": + text_parts.append(part.text) + return "".join(text_parts) + + +class ResponsesEventBase(BaseModel): + """Base event fields.""" + + sequence_number: int + + +class ResponseCreatedEvent(ResponsesEventBase): + type: Literal["response.created"] = "response.created" + response: ResponseObject + + +class ResponseInProgressEvent(ResponsesEventBase): + type: Literal["response.in_progress"] = "response.in_progress" + response: ResponseObject + + +class ResponseCompletedEvent(ResponsesEventBase): + type: Literal["response.completed"] = "response.completed" + response: ResponseObject + + +class ResponseOutputItemAddedEvent(ResponsesEventBase): + type: Literal["response.output_item.added"] = "response.output_item.added" + output_index: int + item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + + +class ResponseOutputItemDoneEvent(ResponsesEventBase): + type: Literal["response.output_item.done"] = "response.output_item.done" + output_index: int + item: ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + + +class ResponseContentPartAddedEvent(ResponsesEventBase): + 
type: Literal["response.content_part.added"] = "response.content_part.added" + item_id: str + output_index: int + content_index: int + part: ResponseTextContentPart | ResponseReasoningTextPart + + +class ResponseContentPartDoneEvent(ResponsesEventBase): + type: Literal["response.content_part.done"] = "response.content_part.done" + item_id: str + output_index: int + content_index: int + part: ResponseTextContentPart | ResponseReasoningTextPart + + +class ResponseOutputTextDeltaEvent(ResponsesEventBase): + type: Literal["response.output_text.delta"] = "response.output_text.delta" + item_id: str + output_index: int + content_index: int + delta: str + logprobs: list[dict] = Field(default_factory=list) + + +class ResponseOutputTextDoneEvent(ResponsesEventBase): + type: Literal["response.output_text.done"] = "response.output_text.done" + item_id: str + output_index: int + content_index: int + text: str + logprobs: list[dict] = Field(default_factory=list) + + +class ResponseReasoningTextDeltaEvent(ResponsesEventBase): + type: Literal["response.reasoning_text.delta"] = "response.reasoning_text.delta" + item_id: str + output_index: int + content_index: int + delta: str + + +class ResponseReasoningTextDoneEvent(ResponsesEventBase): + type: Literal["response.reasoning_text.done"] = "response.reasoning_text.done" + item_id: str + output_index: int + content_index: int + text: str + + +class ResponseFunctionCallArgumentsDeltaEvent(ResponsesEventBase): + type: Literal["response.function_call_arguments.delta"] = ( + "response.function_call_arguments.delta" + ) + item_id: str + output_index: int + delta: str diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py index 364b65993..5bdd24b79 100644 --- a/vllm_mlx/api/tool_calling.py +++ b/vllm_mlx/api/tool_calling.py @@ -21,6 +21,34 @@ from .models import FunctionCall, ResponseFormat, ToolCall +def _looks_like_tool_call(obj: Any) -> bool: + """ + Heuristic: decide whether a parsed JSON object really represents a tool + call as opposed to user data that happens to carry a ``"name"`` field. + + The OpenAI tool-call wire format ALWAYS has both ``"name"`` and + ``"arguments"``. Accepting bare ``{"name": ...}`` (previous behaviour) + caused ``response_format={"type": "json_schema"}`` payloads with a + ``name`` field to be hijacked as fake tool calls (observed on + MiniMax-M2: ``{"name": "John", "age": 25}`` -> ``function.name="John"``). + + Args: + obj: Parsed JSON object. + + Returns: + True if obj looks like a tool call, False otherwise. + """ + if not isinstance(obj, dict): + return False + if "name" not in obj or "arguments" not in obj: + return False + if not isinstance(obj["name"], str) or not obj["name"]: + return False + # ``arguments`` must be a JSON-encoded string or a dict per OpenAI spec. + args = obj["arguments"] + return isinstance(args, (dict, str)) + + def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: """ Parse raw JSON tool calls from model output. @@ -30,6 +58,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: - Multiple objects separated by commas: {...}, {...} - JSON array: [{...}, {...}] + Only accepts objects that carry both ``name`` AND ``arguments`` fields + to avoid hijacking user data emitted via ``response_format``. 
+ Args: text: Raw model output text @@ -45,11 +76,13 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: if text.startswith("["): try: parsed = json.loads(text) - if isinstance(parsed, list) and all( - isinstance(item, dict) and "name" in item for item in parsed + if ( + isinstance(parsed, list) + and parsed + and all(_looks_like_tool_call(item) for item in parsed) ): return [ - {"name": item["name"], "arguments": item.get("arguments", {})} + {"name": item["name"], "arguments": item["arguments"]} for item in parsed ] except json.JSONDecodeError: @@ -71,9 +104,9 @@ def _parse_raw_json_tool_calls(text: str) -> Optional[List[dict]]: json_str = text[start : i + 1] try: obj = json.loads(json_str) - if isinstance(obj, dict) and "name" in obj: + if _looks_like_tool_call(obj): tool_calls.append( - {"name": obj["name"], "arguments": obj.get("arguments", {})} + {"name": obj["name"], "arguments": obj["arguments"]} ) except json.JSONDecodeError: pass @@ -280,8 +313,13 @@ def parse_tool_calls( # The user may want to see the model's reasoning process # Fallback: Raw JSON tool calls (lowest priority) - # Only try if no other formats matched - if not tool_calls: + # Only try if no other formats matched AND the caller actually asked for + # tools. When ``tools`` is not set, a bare JSON response is user data + # (e.g. from ``response_format``), not a tool invocation — the fallback + # would otherwise hijack ``{"name": "John", ...}`` into a fake + # ``function.name="John"`` tool call (observed on MiniMax-M2). + tools_requested = bool(request and request.get("tools")) + if not tool_calls and tools_requested: raw_json_calls = _parse_raw_json_tool_calls(cleaned_text) if raw_json_calls: for call_data in raw_json_calls: @@ -410,14 +448,145 @@ def validate_json_schema( return False, str(e.message) +def _scan_balanced_json(text: str, start: int) -> Optional[str]: + """ + Walk forward from ``start`` (which must point at ``{`` or ``[``) and + return the substring that represents the first balanced JSON value, + respecting strings and escapes. Returns ``None`` if the opening bracket + is never closed (truncated output). + """ + if start < 0 or start >= len(text): + return None + opener = text[start] + if opener not in "{[": + return None + closer = "}" if opener == "{" else "]" + depth = 0 + in_string = False + escape = False + for i in range(start, len(text)): + ch = text[i] + if in_string: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == '"': + in_string = False + continue + if ch == '"': + in_string = True + continue + if ch == opener: + depth += 1 + elif ch == closer: + depth -= 1 + if depth == 0: + return text[start : i + 1] + return None + + +def _repair_truncated_json(fragment: str) -> Optional[Dict[str, Any]]: + """ + Attempt to parse a JSON fragment whose closing brackets were cut off + (e.g. because the model hit ``max_tokens`` mid-object). + + Strategy: scan once to determine the open-bracket stack and whether we + ended mid-string, then try a handful of repair candidates in order of + likelihood: + + 1. Close unterminated string, close brackets. + 2. Also strip a dangling ``,`` / ``:`` before closing. + 3. Also drop a dangling key (``"k":`` or bare ``"k"``) before closing. + 4. Drop a dangling partial token (number / true / fals / nul) before + closing. + + Returns the first candidate that ``json.loads`` accepts, or ``None``. 
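+
+    For example (illustrative truncated fragments)::
+
+        '{"name": "Jo'        ->  {"name": "Jo"}
+        '{"items": [1, 2'     ->  {"items": [1, 2]}
+        '{"count": 12, "ok"'  ->  {"count": 12}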
+ """ + if not fragment: + return None + stack: list[str] = [] + in_string = False + escape = False + for ch in fragment: + if in_string: + if escape: + escape = False + elif ch == "\\": + escape = True + elif ch == '"': + in_string = False + continue + if ch == '"': + in_string = True + elif ch in "{[": + stack.append(ch) + elif ch in "}]": + if stack and ( + (ch == "}" and stack[-1] == "{") or (ch == "]" and stack[-1] == "[") + ): + stack.pop() + if not stack and not in_string: + return None # Nothing to repair — caller already tried json.loads. + + def _close(text: str) -> str: + for opener in reversed(stack): + text += "}" if opener == "{" else "]" + return text + + # Base: close unterminated string if any. + base = fragment + if in_string: + if escape: + base = base[:-1] # drop trailing backslash so closing quote sticks + base += '"' + + candidates: list[str] = [] + + # 1. Close brackets directly. + candidates.append(_close(base)) + + # 2. Strip a dangling separator (trailing ``,`` / ``:`` / whitespace). + stripped_sep = re.sub(r"[,:\s]+$", "", base) + if stripped_sep != base: + candidates.append(_close(stripped_sep)) + + # 3. Drop a dangling key (``"k":`` or bare ``"k"``) inside an object. + if stack and stack[-1] == "{": + no_key = re.sub(r',?\s*"[^"]*"\s*:?\s*$', "", stripped_sep) + if no_key != stripped_sep: + candidates.append(_close(no_key)) + + # 4. Drop a dangling partial scalar (number / keyword / literal). + no_scalar = re.sub( + r",?\s*(?:-?\d+(?:\.\d*)?(?:[eE][+-]?\d*)?|t|tr|tru|f|fa|fal|fals|n|nu|nul)$", + "", + stripped_sep, + ) + if no_scalar != stripped_sep: + candidates.append(_close(no_scalar)) + + for candidate in candidates: + try: + return json.loads(candidate) + except json.JSONDecodeError: + continue + return None + + def extract_json_from_text(text: str) -> Optional[Dict[str, Any]]: """ Extract JSON from model output text. - Tries multiple strategies: + Tries multiple strategies, in order of specificity: + 1. Parse entire text as JSON - 2. Extract JSON from markdown code blocks - 3. Find JSON object/array in text + 2. Extract JSON from complete markdown code blocks (``` ... ```) + 3. Extract JSON from an unterminated markdown code block (``` json\\n{ ... ) + — handles the common "chatty + truncation" failure mode where the + model starts a ```json fence, never closes it, then hits max_tokens. + 4. Balanced-brace scan for the first ``{`` or ``[`` in the text + 5. Repair truncated JSON by closing unclosed brackets/strings Args: text: Raw model output text @@ -433,7 +602,7 @@ def extract_json_from_text(text: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError: pass - # Strategy 2: Extract from markdown code blocks + # Strategy 2: Extract from complete markdown code blocks # Match ```json ... ``` or ``` ... ``` code_block_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" matches = re.findall(code_block_pattern, text) @@ -443,23 +612,173 @@ def extract_json_from_text(text: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError: continue - # Strategy 3: Find JSON object or array in text - # Look for { ... } or [ ... ] - json_patterns = [ - r"(\{[\s\S]*\})", # Object - r"(\[[\s\S]*\])", # Array - ] - for pattern in json_patterns: - match = re.search(pattern, text) - if match: - try: - return json.loads(match.group(1)) - except json.JSONDecodeError: - continue + # Strategy 3: Unterminated markdown fence — take everything after the + # last ``` opener. Common truncation case for chatty models. 
+ unterminated_fence = re.search(r"```(?:json)?\s*\n?([\s\S]*)$", text) + fenced_candidate: Optional[str] = None + if unterminated_fence: + fenced_candidate = unterminated_fence.group(1).strip() + # Drop a trailing ``` if it slipped through the greedy match. + if fenced_candidate.endswith("```"): + fenced_candidate = fenced_candidate[:-3].strip() + try: + return json.loads(fenced_candidate) + except json.JSONDecodeError: + pass + + # Strategy 4: Balanced-brace scan for the first JSON value anywhere. + # Preserves correctness better than a greedy regex (which would grab + # everything between the first ``{`` and the last ``}``). + for opener in ("{", "["): + idx = text.find(opener) + while idx != -1: + candidate = _scan_balanced_json(text, idx) + if candidate is not None: + try: + return json.loads(candidate) + except json.JSONDecodeError: + pass + idx = text.find(opener, idx + 1) + + # Strategy 5: Repair a truncated JSON prefix. Try the fenced candidate + # first (usually starts right at ``{``), then fall back to the earliest + # ``{``/``[`` in the full text. + candidates = [] + if fenced_candidate: + candidates.append(fenced_candidate) + for opener in ("{", "["): + idx = text.find(opener) + if idx != -1: + candidates.append(text[idx:]) + for fragment in candidates: + repaired = _repair_truncated_json(fragment) + if repaired is not None: + return repaired return None +class StreamingJsonFenceStripper: + """Strip markdown code fences from streamed content when response_format is set. + + Without guided decoding, chat models often wrap their JSON output in markdown + fences (```json ... ```) even when the system prompt says not to. The non- + streaming path strips those via ``extract_json_from_text`` / ``parse_json_output``, + but the streaming path used to emit the raw deltas, so clients got + ``"```json{...}```"`` instead of ``"{...}"``. + + This filter buffers just enough text to detect: + * a leading fence like ``"```"``, ``"```json"``, ``"```\\n"`` or + ``"```json\\n"`` (with optional leading whitespace), possibly split + across SSE deltas, and + * a trailing fence like ``"```"`` or ``"\\n```\\n"`` on stream end. + + Leading-whitespace and leading fences are consumed; trailing fences are + dropped in :meth:`finalize`. Non-fenced content passes through with at most + a ``_TAIL_HOLDBACK``-char delay. + """ + + # Opening fence forms to strip, longest first (longest match wins). + _OPENINGS = ("```json\n", "```json", "```\n", "```") + # Characters held back at the tail to detect a trailing fence across deltas. + # 5 covers ``"\n```\n"``. + _TAIL_HOLDBACK = 5 + + def __init__(self) -> None: + self._buf: str = "" + self._past_opening: bool = False + + def feed(self, delta: str) -> str: + """Append a content delta and return the portion safe to emit now.""" + if not delta: + return "" + + self._buf += delta + + if not self._past_opening: + ls = self._buf.lstrip() + if not ls: + # Still only whitespace — wait for content. + return "" + + # If ``ls`` is a strict prefix of any opening (shorter than the + # opening), the rest of the fence might still arrive in the next + # delta — keep buffering. + for opening in self._OPENINGS: + if len(ls) < len(opening) and opening.startswith(ls): + return "" + + # Try to match a complete opening fence (longest first). + matched: Optional[str] = None + for opening in self._OPENINGS: + if ls.startswith(opening): + matched = opening + break + + if matched is not None: + # Drop the fence and any immediate whitespace that followed. 
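+                # e.g. ls = '```json\n{"a"' -> matched = '```json\n', self._buf = '{"a"'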
+ self._buf = ls[len(matched) :].lstrip() + else: + # Not a fence — keep the stripped text. + self._buf = ls + self._past_opening = True + + # Dynamic holdback: walk backwards across any trailing whitespace and + # backticks so a closing fence cannot straddle the emit boundary, and + # additionally keep at least ``_TAIL_HOLDBACK`` chars to absorb fences + # that span delta boundaries. + buf = self._buf + i = len(buf) + while i > 0 and (buf[i - 1] == "`" or buf[i - 1].isspace()): + i -= 1 + safe_end = min(i, len(buf) - self._TAIL_HOLDBACK) + if safe_end <= 0: + return "" + + to_emit = buf[:safe_end] + self._buf = buf[safe_end:] + return to_emit + + def finalize(self) -> str: + """Flush the remaining buffer, dropping any trailing fence.""" + tail = self._buf + self._buf = "" + if not tail: + return "" + + if not self._past_opening: + # Stream ended before we ever transitioned past the (potential) + # opening — either drop a strict-prefix fence that never finished + # arriving, or strip a complete leading fence. + ls = tail.lstrip() + # Check strict-prefix FIRST: ``"```js"`` is a prefix of + # ``"```json\n"`` and should be dropped entirely rather than + # partially matched by the bare ``"```"`` opening. + is_partial_prefix = False + for opening in self._OPENINGS: + if len(ls) < len(opening) and opening.startswith(ls): + is_partial_prefix = True + break + if is_partial_prefix: + ls = "" + else: + matched: Optional[str] = None + for opening in self._OPENINGS: + if ls.startswith(opening): + matched = opening + break + if matched is not None: + ls = ls[len(matched) :].lstrip() + tail = ls + self._past_opening = True + + stripped = tail.rstrip() + for closing in ("\n```", "```"): + if stripped.endswith(closing): + return stripped[: -len(closing)].rstrip() + return tail + + def parse_json_output( text: str, response_format: Optional[Union[ResponseFormat, Dict[str, Any]]] = None ) -> Tuple[str, Optional[Dict[str, Any]], bool, Optional[str]]: @@ -564,10 +883,19 @@ def build_json_system_prompt( if format_type == "text": return None + strict_rules = ( + "Output rules (STRICT):\n" + "- Your first character MUST be `{` (or `[`).\n" + "- Your last character MUST be `}` (or `]`).\n" + "- Do NOT wrap the JSON in a markdown code block (no ``` fences).\n" + "- Do NOT prepend any preamble such as " + '"Here is the JSON" or "Let me format this".\n' + "- Do NOT include comments, trailing explanations, or chain-of-thought.\n" + ) + if format_type == "json_object": return ( - "You must respond with valid JSON only. " - "Do not include any explanation or text outside the JSON object." + "You must respond with a single valid JSON value only.\n\n" + strict_rules ) if format_type == "json_schema": @@ -576,13 +904,93 @@ def build_json_system_prompt( name = json_schema_spec.get("name", "response") description = json_schema_spec.get("description", "") - prompt = f"You must respond with valid JSON matching the '{name}' schema." + prompt = f"You must respond with a single valid JSON value matching the '{name}' schema." if description: prompt += f" {description}" - prompt += ( - f"\n\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n" - "Respond with only the JSON object, no additional text or explanation." 
- ) + prompt += f"\n\nJSON Schema:\n```json\n{json.dumps(schema, indent=2)}\n```\n\n" + prompt += strict_rules return prompt return None + + +def build_json_logits_processor( + response_format: ResponseFormat | dict[str, Any] | None, + tokenizer: Any, +): + """ + Build a logits processor that constrains generation to valid JSON matching + ``response_format``. + + Unlike :func:`build_json_system_prompt` which nudges the model via the + system prompt, this processor masks logits at every generation step so + the model *cannot* emit invalid JSON (grammar-guided decoding). + + Args: + response_format: ``ResponseFormat`` specification (or dict). + tokenizer: The tokenizer used by the engine. May be a HF tokenizer, + a ``mlx_lm.TokenizerWrapper``, or a VLM ``processor``; the + underlying tokenizer is resolved automatically. + + Returns: + A callable ``(tokens, logits) -> logits`` suitable for passing to + ``mlx_lm.stream_generate`` via ``logits_processors``. ``None`` when + no constraint is needed (e.g. ``type=text``) or when constrained + decoding cannot be enabled (missing optional dependency, tokenizer + incompatibility) — in that case the caller should fall back to the + system-prompt path. + """ + if response_format is None: + return None + + # Normalize to dict + if isinstance(response_format, ResponseFormat): + format_type = response_format.type + schema: dict | None = None + if response_format.json_schema is not None: + schema = response_format.json_schema.schema_ + elif isinstance(response_format, dict): + format_type = response_format.get("type", "text") + json_schema_spec = response_format.get("json_schema") or {} + if isinstance(json_schema_spec, dict): + schema = json_schema_spec.get("schema") + else: + schema = getattr(json_schema_spec, "schema_", None) or getattr( + json_schema_spec, "schema", None + ) + else: + return None + + if format_type == "text": + return None + + if format_type not in ("json_object", "json_schema"): + return None + + try: + from ..constrained import ( + JSONSchemaLogitsProcessor, + LMFormatEnforcerNotAvailableError, + is_available, + ) + except ImportError: + # Constrained decoding module could not be imported; fall back. + return None + + if not is_available(): + # ``lm-format-enforcer`` optional dependency missing; fall back. + return None + + # ``json_schema`` without an actual schema degrades to ``json_object`` + # (both paths pass ``schema=None`` to the processor). + if format_type == "json_object" or (format_type == "json_schema" and not schema): + schema = None + + try: + return JSONSchemaLogitsProcessor(schema=schema, tokenizer=tokenizer) + except LMFormatEnforcerNotAvailableError: + return None + except Exception: + # Malformed schema or tokenizer issue — fall back to prompt-only + # mode. Callers still run post-hoc validation as a safety net. + return None diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 7ef9ffef6..6d70ed721 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -384,6 +384,45 @@ def is_mllm_model(model_name: str) -> bool: is_vlm_model = is_mllm_model +# ============================================================================= +# Media Content Detection +# ============================================================================= + +MEDIA_CONTENT_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio). 
+ + Handles both plain dicts (``msg.get("content")``) and Pydantic-style + objects (``msg.content``) so it works in both engine and server contexts. + """ + for msg in messages: + content = ( + msg.get("content") + if isinstance(msg, dict) + else getattr(msg, "content", None) + ) + if isinstance(content, list): + for part in content: + if isinstance(part, dict): + part_type = part.get("type") + else: + part_type = getattr(part, "type", None) + if part_type in MEDIA_CONTENT_TYPES: + return True + return False + + # ============================================================================= # Multimodal Content Extraction # ============================================================================= diff --git a/vllm_mlx/audio_limits.py b/vllm_mlx/audio_limits.py new file mode 100644 index 000000000..f2515f714 --- /dev/null +++ b/vllm_mlx/audio_limits.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Resource limits for optional audio endpoints.""" + +import os +import tempfile +from pathlib import Path +from typing import Protocol + +from fastapi import HTTPException + +DEFAULT_MAX_AUDIO_UPLOAD_MB = 25 +DEFAULT_MAX_AUDIO_UPLOAD_BYTES = DEFAULT_MAX_AUDIO_UPLOAD_MB * 1024 * 1024 +DEFAULT_MAX_TTS_INPUT_CHARS = 4096 +UPLOAD_CHUNK_SIZE = 1024 * 1024 + + +class AsyncReadableUpload(Protocol): + filename: str | None + + async def read(self, size: int = -1) -> bytes: ... + + +async def save_upload_with_limit( + file: AsyncReadableUpload, + *, + max_bytes: int, + default_suffix: str = ".wav", + chunk_size: int = UPLOAD_CHUNK_SIZE, +) -> str: + """ + Stream an uploaded file to disk while enforcing a hard byte limit. + + This prevents large audio uploads from being buffered entirely in memory. + """ + suffix = Path(file.filename or "").suffix or default_suffix + total_bytes = 0 + + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + tmp_path = tmp.name + try: + while True: + chunk = await file.read(chunk_size) + if not chunk: + break + total_bytes += len(chunk) + if total_bytes > max_bytes: + raise HTTPException( + status_code=413, + detail=( + f"Audio upload too large: {total_bytes} bytes exceeds " + f"the configured limit of {max_bytes} bytes." + ), + ) + tmp.write(chunk) + except Exception: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + raise + + return tmp_path + + +def validate_tts_input_length(text: str, *, max_chars: int) -> None: + """Reject oversized TTS requests before synthesis starts.""" + if len(text) > max_chars: + raise HTTPException( + status_code=413, + detail=( + f"TTS input too long: {len(text)} characters exceeds the configured " + f"limit of {max_chars} characters." + ), + ) diff --git a/vllm_mlx/bench_serve.py b/vllm_mlx/bench_serve.py new file mode 100644 index 000000000..bec5d21a5 --- /dev/null +++ b/vllm_mlx/bench_serve.py @@ -0,0 +1,1378 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Serving benchmark for vllm-mlx. + +Measures end-to-end HTTP performance of a running vllm-mlx server: +- Time to First Token (TTFT) +- Time Per Output Token (TPOT) +- End-to-end latency +- Generation and prompt throughput +- Concurrent request handling +- KV cache hit rates +- Metal memory utilization + +This module has no MLX dependency and can be imported on any platform. +It is a pure HTTP client that talks to a running OpenAI-compatible server. 
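+
+Illustrative usage (``run_bench_serve`` is defined later in this module)::
+
+    import asyncio
+    from vllm_mlx.bench_serve import run_bench_serve
+
+    results = asyncio.run(
+        run_bench_serve(url="http://localhost:8000", concurrencies=[1, 4], fmt="json")
+    )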
+""" + +import asyncio +import csv as csv_mod +import dataclasses as _dataclasses +import io +import itertools +import json +import logging +import math +import platform +import re +import statistics +import sys +import time +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +import httpx +from tabulate import tabulate as _tabulate + +# --------------------------------------------------------------------------- +# Prompt set loading +# --------------------------------------------------------------------------- + +_BUILTIN_DIR = Path(__file__).parent / "bench_serve_prompts" +_BUILTIN_NAMES = {"short", "medium", "long", "thinking"} + + +def load_prompt_set(name_or_path: str) -> list[list[dict]]: + """Load a prompt set by builtin name or file path. + + Builtin sets (``short``, ``medium``, ``long``, ``thinking``) are loaded + from the ``bench_serve_prompts/`` directory next to this module. Any + other value is treated as a filesystem path and loaded directly. + + Two file formats are accepted (detected automatically): + + 1. **Flat** — list of single message dicts. Each dict becomes a + single-message prompt. Backwards-compatible with the original format. + + ``[{"role": "user", "content": "..."}, ...]`` + + 2. **Multi-message** — list of message-dict lists. Each inner list is a + full chat history (e.g. ``[system, user]``). Use this format when you + want to benchmark with system prompts that match an ``--warm-prompts`` + warm-up, or to simulate multi-turn conversation. + + ``[[{"role":"system","content":"..."}, {"role":"user","content":"..."}], ...]`` + + Returns: + A list of message-dict lists, i.e. every entry is a full chat history. + Flat-format files are normalized to single-element lists. + + Raises: + FileNotFoundError: If ``name_or_path`` is not a known builtin name and + the path does not exist, or if a builtin name is requested but its + JSON file is missing from the package. + ValueError: If the file shape is not recognised. + """ + if name_or_path in _BUILTIN_NAMES: + target = _BUILTIN_DIR / f"{name_or_path}.json" + if not target.exists(): + raise FileNotFoundError( + f"Builtin prompt set '{name_or_path}' not found at {target}" + ) + with target.open() as fh: + raw = json.load(fh) + else: + path = Path(name_or_path).expanduser() + if not path.exists(): + raise FileNotFoundError( + f"Unknown prompt set name or missing file: '{name_or_path}'. " + f"Builtin names are: {sorted(_BUILTIN_NAMES)}" + ) + with path.open() as fh: + raw = json.load(fh) + + if not isinstance(raw, list) or not raw: + raise ValueError(f"Prompt file must be a non-empty JSON list: {name_or_path}") + + # Auto-detect format: dict entries = flat; list entries = multi-message. + first = raw[0] + if isinstance(first, dict): + # Flat format: wrap each message in a single-element list. 
+ return [[msg] for msg in raw] + if isinstance(first, list): + return raw + raise ValueError( + f"Prompt entries must be dict or list, got {type(first).__name__} " + f"in {name_or_path}" + ) + + +# --------------------------------------------------------------------------- +# Result dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class BenchServeResult: + """Aggregated results from a single bench-serve run configuration.""" + + # --- Identity --- + run_id: str = "" + timestamp: str = "" + tag: str = "" + + # --- Hardware --- + chip: str = "" + gpu_cores: int = 0 + memory_gb: float = 0.0 + bandwidth_gbs: float = 0.0 + os_version: str = "" + + # --- Runtime --- + model_id: str = "" + model_type: str = "" + engine_type: str = "" + mtp_enabled: bool = False + specprefill: bool = False + kv_quant: str = "" + cache_type: str = "" + + # --- Config --- + prompt_set: str = "" + concurrency: int = 1 + max_tokens: int = 256 + enable_thinking: Optional[bool] = None + extra_body: str = "" + repetition: int = 0 + prompt_tokens: int = 0 + + # --- Latency (milliseconds) --- + ttft_ms: float = 0.0 + tpot_ms: float = 0.0 + e2e_latency_ms: float = 0.0 + + # --- Throughput --- + gen_tps: float = 0.0 + prompt_tps: float = 0.0 + throughput_tps: float = 0.0 + requests_per_s: float = 0.0 + + # --- Memory (gigabytes) --- + metal_active_gb: float = 0.0 + metal_peak_gb: float = 0.0 + metal_cache_gb: float = 0.0 + + # --- Cache --- + cache_hits: int = 0 + cache_misses: int = 0 + cache_hit_rate: float = 0.0 + tokens_saved: int = 0 + + # --- Validation --- + validated: bool = True + + +# --------------------------------------------------------------------------- +# Sweep configuration type alias and combinatorial expansion +# --------------------------------------------------------------------------- + +# (prompt_set, concurrency, thinking, extra_body, repetition_index) +SweepConfig = tuple[str, int, Optional[bool], str, int] + + +def expand_sweep( + prompt_sets: list[str], + concurrencies: list[int], + thinking_values: list[Optional[bool]], + extra_bodies: list[str], + repetitions: int, +) -> list[SweepConfig]: + """Expand sweep parameters into a flat list of configurations. + + Performs the full Cartesian product of all input dimensions and then + unfolds each combination across ``repetitions`` repetition indices + (0-based). + + Args: + prompt_sets: Names or paths of prompt sets to include. + concurrencies: Concurrency levels to test (e.g. ``[1, 4, 16]``). + thinking_values: Values for ``enable_thinking`` (e.g. + ``[None, True, False]``). + extra_bodies: JSON strings (or empty string) to pass as extra body + parameters on each request. + repetitions: Number of times to repeat each unique combination. + Each repeat gets a distinct 0-based repetition index. 
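+
+    For example, ``expand_sweep(["short"], [1, 4], [None], [""], 2)`` yields,
+    in product order::
+
+        ("short", 1, None, "", 0), ("short", 1, None, "", 1),
+        ("short", 4, None, "", 0), ("short", 4, None, "", 1)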
+ + Returns: + A list of :data:`SweepConfig` tuples in the order:: + + (prompt_set, concurrency, thinking, extra_body, repetition_index) + """ + configs: list[SweepConfig] = [] + for prompt_set, concurrency, thinking, extra_body, rep in itertools.product( + prompt_sets, concurrencies, thinking_values, extra_bodies, range(repetitions) + ): + configs.append((prompt_set, concurrency, thinking, extra_body, rep)) + return configs + + +# --------------------------------------------------------------------------- +# Task 3: Server auto-detection + hardware fingerprint +# --------------------------------------------------------------------------- + + +def parse_health_response(data: dict) -> dict: + """Extract model identity fields from a GET /health response. + + Args: + data: Parsed JSON body from the /health endpoint. Expected shape:: + + {"status": "healthy", "model_loaded": True, + "model_name": "...", "model_type": "llm"|"mllm"} + + Returns: + ``{"model_name": str, "model_type": str}`` + """ + return { + "model_name": data.get("model_name", ""), + "model_type": data.get("model_type", ""), + } + + +def parse_status_response(data: dict) -> dict: + """Extract metal and cache info from a GET /v1/status response. + + Args: + data: Parsed JSON body from the /v1/status endpoint. Metal info is + expected under ``data["metal"]`` and cache info under + ``data["cache"]``. Missing keys are handled gracefully. + + Returns: + ``{"model": str, "metal_active_gb": float, "metal_peak_gb": float, + "metal_cache_gb": float, "cache_type": str}`` + """ + metal = data.get("metal") or {} + cache = data.get("cache") or {} + return { + "model": data.get("model", ""), + "metal_active_gb": float(metal.get("active_gb") or 0.0), + "metal_peak_gb": float(metal.get("peak_gb") or 0.0), + "metal_cache_gb": float(metal.get("cache_gb") or 0.0), + "cache_type": cache.get("type", "") or "", + } + + +def parse_metrics_text(text: str) -> dict: + """Parse Prometheus text exposition format from GET /metrics. + + Extracts the three prefix-cache counters used for bench reporting. + + Args: + text: Raw response body from the /metrics endpoint. + + Returns: + ``{"cache_hits": int, "cache_misses": int, "tokens_saved": int}`` + — each value defaults to ``0`` when the metric line is absent. + """ + + def _extract(metric_name: str) -> int: + pattern = rf"^{re.escape(metric_name)}\s+(\d+)" + m = re.search(pattern, text, re.MULTILINE) + return int(m.group(1)) if m else 0 + + return { + "cache_hits": _extract("vllm_prefix_cache_hits_total"), + "cache_misses": _extract("vllm_prefix_cache_misses_total"), + "tokens_saved": _extract("vllm_prefix_cache_tokens_saved_total"), + } + + +def detect_hardware_fingerprint() -> dict: + """Return a hardware fingerprint dict for the current machine. + + Tries to import :func:`vllm_mlx.optimizations.detect_hardware` (which + requires MLX). Falls back to reading ``hw.memsize`` via ``sysctl`` when + MLX is unavailable. ``os_version`` is always obtained from + :func:`platform.platform`. 
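+
+    A typical return value looks like (illustrative values)::
+
+        {"chip": "Apple M4 Max", "gpu_cores": 40, "memory_gb": 128.0,
+         "bandwidth_gbs": 546.0, "os_version": "macOS-15.1-arm64-arm-64bit"}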
+ + Returns: + ``{"chip": str, "gpu_cores": int, "memory_gb": float, + "bandwidth_gbs": float, "os_version": str}`` + """ + os_version = platform.platform() + + try: + from .optimizations import detect_hardware # type: ignore[import] + + hw = detect_hardware() + return { + "chip": hw.chip_name, + "gpu_cores": hw.gpu_cores, + "memory_gb": hw.total_memory_gb, + "bandwidth_gbs": hw.memory_bandwidth_gbs, + "os_version": os_version, + } + except Exception: + pass + + # Fallback: use sysctl for memory, leave chip/cores/bandwidth unknown. + memory_gb = 0.0 + try: + import subprocess + + result = subprocess.run( + ["sysctl", "-n", "hw.memsize"], + capture_output=True, + text=True, + check=True, + ) + memory_gb = int(result.stdout.strip()) / (1024**3) + except Exception: + pass + + return { + "chip": "", + "gpu_cores": 0, + "memory_gb": memory_gb, + "bandwidth_gbs": 0.0, + "os_version": os_version, + } + + +async def auto_detect_runtime(client: httpx.AsyncClient, base_url: str) -> dict: + """Query the running server and return a runtime descriptor dict. + + Hits ``/health``, ``/v1/models``, and ``/v1/status`` in sequence. + Each call is wrapped in an :exc:`httpx.HTTPError` guard so a missing + endpoint does not abort the whole detection. + + Args: + client: An open :class:`httpx.AsyncClient`. + base_url: Base URL of the server (e.g. ``"http://localhost:8080"``). + + Returns: + Dict with keys: ``model_id``, ``model_type``, ``engine_type``, + ``mtp_enabled``, ``specprefill``, ``kv_quant``, ``cache_type``, + ``metal_active_gb``, ``metal_peak_gb``, ``metal_cache_gb``. + """ + result: dict = { + "model_id": "", + "model_type": "", + "engine_type": "", + "mtp_enabled": False, + "specprefill": False, + "kv_quant": "", + "cache_type": "", + "metal_active_gb": 0.0, + "metal_peak_gb": 0.0, + "metal_cache_gb": 0.0, + } + + # /health + try: + resp = await client.get(f"{base_url}/health") + resp.raise_for_status() + health = parse_health_response(resp.json()) + result["model_type"] = health.get("model_type", "") + except httpx.HTTPError: + pass + + # /v1/models + try: + resp = await client.get(f"{base_url}/v1/models") + resp.raise_for_status() + models_data = resp.json() + models = models_data.get("data") or [] + if models: + result["model_id"] = models[0].get("id", "") + except httpx.HTTPError: + pass + + # /v1/status + try: + resp = await client.get(f"{base_url}/v1/status") + resp.raise_for_status() + status = parse_status_response(resp.json()) + result["cache_type"] = status.get("cache_type", "") + result["metal_active_gb"] = status.get("metal_active_gb", 0.0) + result["metal_peak_gb"] = status.get("metal_peak_gb", 0.0) + result["metal_cache_gb"] = status.get("metal_cache_gb", 0.0) + raw = resp.json() + result["engine_type"] = raw.get("engine_type", "") + result["mtp_enabled"] = bool(raw.get("mtp_enabled", False)) + result["specprefill"] = bool(raw.get("specprefill", False)) + result["kv_quant"] = raw.get("kv_quant", "") or "" + except httpx.HTTPError: + pass + + return result + + +async def scrape_metrics(client: httpx.AsyncClient, base_url: str) -> dict: + """Scrape Prometheus metrics from the server. + + Args: + client: An open :class:`httpx.AsyncClient`. + base_url: Base URL of the server. + + Returns: + Parsed metrics dict (see :func:`parse_metrics_text`), or an empty + dict if the endpoint is unreachable. 
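+        For example (illustrative counters): ``{"cache_hits": 42,
+        "cache_misses": 3, "tokens_saved": 1024}``.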
+ """ + try: + resp = await client.get(f"{base_url}/metrics") + resp.raise_for_status() + return parse_metrics_text(resp.text) + except Exception: + return {} + + +# --------------------------------------------------------------------------- +# Task 4: SSE streaming core + token counting + request timing +# --------------------------------------------------------------------------- + + +def parse_sse_line(line: str) -> Optional[dict]: + """Parse one Server-Sent Events line from a streaming chat completion. + + Args: + line: A single raw line from the SSE stream (may or may not include + a trailing newline — it is stripped before processing). + + Returns: + ``None`` for blank lines, comment lines (starting with ``:``) and the + ``data: [DONE]`` sentinel. For all other ``data:`` lines the JSON is + parsed and a dict is returned:: + + {"content": str, "finish_reason": Optional[str], "usage": Optional[dict]} + + Missing keys (``choices``, ``delta``, ``content``) are handled + gracefully and default to empty string / ``None``. + """ + line = line.strip() + if not line: + return None + if line.startswith(":"): + return None + if line == "data: [DONE]": + return None + if not line.startswith("data: "): + return None + + payload = line[len("data: ") :] + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + return None + + choices = chunk.get("choices") or [] + delta = choices[0].get("delta", {}) if choices else {} + content = delta.get("content", "") or "" + finish_reason = choices[0].get("finish_reason") if choices else None + usage = chunk.get("usage") + + return { + "content": content, + "finish_reason": finish_reason, + "usage": usage, + } + + +def compute_request_metrics( + t_start: float, + t_first_token: float, + token_times: list, + t_end: float, + prompt_tokens: int, + completion_tokens: int, +) -> dict: + """Compute standard latency and throughput metrics for a single request. + + All time arguments are :func:`time.perf_counter` values (seconds as + floats). + + Args: + t_start: Timestamp immediately before the request was sent. + t_first_token: Timestamp when the first content token was received. + token_times: List of timestamps, one per content token (including the + first). When there is only one token ``tpot_ms`` is ``0.0``. + t_end: Timestamp after the final SSE chunk was consumed. + prompt_tokens: Number of prompt tokens reported by the server. + completion_tokens: Number of completion tokens generated. + + Returns: + Dict with keys ``ttft_ms``, ``tpot_ms``, ``e2e_latency_ms``, + ``gen_tps``, ``prompt_tps`` — all floats. + """ + ttft_ms = (t_first_token - t_start) * 1000.0 + e2e_latency_ms = (t_end - t_start) * 1000.0 + + # TPOT: mean inter-token gap across all generated tokens. 
+ if len(token_times) > 1: + intervals = [ + token_times[i] - token_times[i - 1] for i in range(1, len(token_times)) + ] + tpot_ms = statistics.mean(intervals) * 1000.0 + else: + tpot_ms = 0.0 + + # Use last token time (not t_end which includes HTTP teardown) + t_last_token = token_times[-1] if token_times else t_end + gen_duration = t_last_token - t_first_token + gen_tps = completion_tokens / gen_duration if gen_duration > 0 else 0.0 + + prompt_duration = t_first_token - t_start + prompt_tps = prompt_tokens / prompt_duration if prompt_duration > 0 else 0.0 + + return { + "ttft_ms": ttft_ms, + "tpot_ms": tpot_ms, + "e2e_latency_ms": e2e_latency_ms, + "gen_tps": gen_tps, + "prompt_tps": prompt_tps, + } + + +async def count_prompt_tokens( + client: httpx.AsyncClient, + base_url: str, + messages: list[dict], + model: str, +) -> int: + """Count prompt tokens for a message list by sending a 1-token request. + + Sends a non-streaming chat completion with ``max_tokens=1`` and reads + ``usage.prompt_tokens`` from the response. + + Args: + client: An open :class:`httpx.AsyncClient`. + base_url: Base URL of the server. + messages: The message list to send. + model: Model ID to target. + + Returns: + Number of prompt tokens, or ``0`` on error. + """ + try: + resp = await client.post( + f"{base_url}/v1/chat/completions", + json={ + "model": model, + "messages": messages, + "max_tokens": 1, + "stream": False, + }, + ) + resp.raise_for_status() + data = resp.json() + return int((data.get("usage") or {}).get("prompt_tokens", 0)) + except Exception: + return 0 + + +async def stream_chat_completion( + client: httpx.AsyncClient, + base_url: str, + messages: list[dict], + model: str, + max_tokens: int = 256, + enable_thinking: Optional[bool] = None, + extra_body: Optional[dict] = None, +) -> dict: + """Send a streaming chat completion and collect per-token timing data. + + Tracks TTFT, per-token timestamps, accumulated content, finish reason, + and usage (via ``stream_options: {"include_usage": True}``). + + Args: + client: An open :class:`httpx.AsyncClient`. + base_url: Base URL of the server. + messages: The message list to send. + model: Model ID to target. + max_tokens: Maximum tokens to generate (default ``256``). + enable_thinking: If not ``None``, passed as ``enable_thinking`` in the + request body. + extra_body: Optional extra keys merged into the request body. + + Returns: + Dict with all :func:`compute_request_metrics` fields plus + ``completion_tokens``, ``prompt_tokens``, ``finish_reason``, + ``content``. 
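+
+        A typical result (illustrative values)::
+
+            {"ttft_ms": 85.0, "tpot_ms": 11.5, "e2e_latency_ms": 3020.0,
+             "gen_tps": 87.0, "prompt_tps": 870.0, "completion_tokens": 256,
+             "prompt_tokens": 74, "finish_reason": "stop", "content": "..."}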
+ """ + body: dict = { + "model": model, + "messages": messages, + "max_tokens": max_tokens, + "stream": True, + "stream_options": {"include_usage": True}, + } + if enable_thinking is not None: + body["enable_thinking"] = enable_thinking + if extra_body: + body.update(extra_body) + + t_start = time.perf_counter() + t_first_token: Optional[float] = None + token_times: list[float] = [] + content_parts: list[str] = [] + finish_reason: Optional[str] = None + usage: Optional[dict] = None + + async with client.stream( + "POST", f"{base_url}/v1/chat/completions", json=body + ) as response: + response.raise_for_status() + async for raw_line in response.aiter_lines(): + parsed = parse_sse_line(raw_line) + if parsed is None: + continue + if parsed.get("usage"): + usage = parsed["usage"] + if parsed.get("finish_reason"): + finish_reason = parsed["finish_reason"] + chunk_content = parsed.get("content", "") + if chunk_content: + now = time.perf_counter() + if t_first_token is None: + t_first_token = now + token_times.append(now) + content_parts.append(chunk_content) + + t_end = time.perf_counter() + if t_first_token is None: + t_first_token = t_end + + prompt_tokens = int((usage or {}).get("prompt_tokens", 0)) + completion_tokens = int((usage or {}).get("completion_tokens", 0)) + + metrics = compute_request_metrics( + t_start=t_start, + t_first_token=t_first_token, + token_times=token_times, + t_end=t_end, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + return { + **metrics, + "completion_tokens": completion_tokens, + "prompt_tokens": prompt_tokens, + "finish_reason": finish_reason, + "content": "".join(content_parts), + } + + +# --------------------------------------------------------------------------- +# Task 5: Concurrent execution + validation + summary statistics +# --------------------------------------------------------------------------- + + +def validate_response( + finish_reason: Optional[str], + content: str, + status_code: int, +) -> tuple[bool, str]: + """Validate a single streaming response result. + + Args: + finish_reason: The ``finish_reason`` from the final SSE chunk, or + ``None`` if not received. + content: The accumulated text content of the response. + status_code: The HTTP status code of the response (use ``200`` for + successful streaming requests). + + Returns: + ``(is_valid, message)`` — ``is_valid`` is ``True`` when the response + passes all checks; ``message`` is an empty string on success or a + human-readable description of the first failure. + """ + if status_code >= 400: + return (False, f"HTTP error {status_code}") + if finish_reason is None: + return (False, "Missing finish_reason") + if finish_reason == "length": + return (False, "Truncated (finish_reason=length)") + if not content: + return (False, "Empty response content") + return (True, "") + + +def compute_summary_stats(values: list[float]) -> dict: + """Compute summary statistics over a list of floats. + + Args: + values: Non-empty list of floats to summarise. + + Returns: + Dict with keys ``mean``, ``stddev``, ``min``, ``max``, ``p50``, + ``p95``, ``p99``. Percentiles use linear interpolation on sorted + values. + + Raises: + ValueError: If ``values`` is empty. 
+ """ + if not values: + raise ValueError("Cannot compute summary stats on empty list") + + n = len(values) + mean = statistics.mean(values) + stddev = 0.0 if n == 1 else statistics.stdev(values) + sorted_vals = sorted(values) + + def _percentile(p: float) -> float: + if n == 1: + return sorted_vals[0] + # Linear interpolation: index = p/100 * (n-1) + idx = p / 100.0 * (n - 1) + lo = int(idx) + hi = lo + 1 + if hi >= n: + return sorted_vals[-1] + frac = idx - lo + return sorted_vals[lo] + frac * (sorted_vals[hi] - sorted_vals[lo]) + + return { + "mean": mean, + "stddev": stddev, + "min": sorted_vals[0], + "max": sorted_vals[-1], + "p50": _percentile(50), + "p95": _percentile(95), + "p99": _percentile(99), + } + + +async def run_concurrent_requests( + client: httpx.AsyncClient, + base_url: str, + prompts: list[list[dict]], + model: str, + concurrency: int, + max_tokens: int = 256, + enable_thinking: Optional[bool] = None, + extra_body: Optional[dict] = None, + do_validate: bool = True, +) -> list[dict]: + """Fire ``concurrency`` concurrent streaming requests and collect results. + + Prompts are selected round-robin from ``prompts``. All requests are + launched simultaneously with :func:`asyncio.gather`. Exceptions are + caught per-task and wrapped in an error dict rather than propagated. + + Args: + client: An open :class:`httpx.AsyncClient`. + base_url: Base URL of the server. + prompts: List of message dicts to cycle through. + model: Model ID to target. + concurrency: Number of simultaneous requests to fire. + max_tokens: Maximum tokens to generate per request (default ``256``). + enable_thinking: Passed through to :func:`stream_chat_completion`. + extra_body: Passed through to :func:`stream_chat_completion`. + do_validate: When ``True``, call :func:`validate_response` on each + result and add a ``"validated"`` key. + + Returns: + List of result dicts (one per request). Each dict contains at minimum + a ``"validated"`` key when ``do_validate`` is ``True``. + """ + prompt_cycle = itertools.cycle(prompts) + selected = [next(prompt_cycle) for _ in range(concurrency)] + + async def _single(messages: list[dict]) -> dict: + try: + result = await stream_chat_completion( + client=client, + base_url=base_url, + messages=messages, + model=model, + max_tokens=max_tokens, + enable_thinking=enable_thinking, + extra_body=extra_body, + ) + if do_validate: + is_valid, _ = validate_response( + finish_reason=result.get("finish_reason"), + content=result.get("content", ""), + status_code=200, + ) + result["validated"] = is_valid + return result + except Exception as exc: + err: dict = { + "error": str(exc), + "validated": False, + } + return err + + results = await asyncio.gather(*[_single(msg) for msg in selected]) + return list(results) + + +# --------------------------------------------------------------------------- +# Task 6: Output formatters +# --------------------------------------------------------------------------- + +RESULT_COLUMNS: list[str] = [f.name for f in _dataclasses.fields(BenchServeResult)] + +_TABLE_COLUMNS = [ + "prompt_set", + "concurrency", + "prompt_tokens", + "ttft_ms", + "tpot_ms", + "gen_tps", + "prompt_tps", + "e2e_latency_ms", + "validated", +] + + +def _result_to_dict(r: BenchServeResult) -> dict: + """Convert a :class:`BenchServeResult` to an ordered dict. + + Returns an ``OrderedDict``-style plain ``dict`` whose keys follow the + dataclass field declaration order (as listed in :data:`RESULT_COLUMNS`). 
+ """ + return {f.name: getattr(r, f.name) for f in _dataclasses.fields(r)} + + +def format_table(results: list[BenchServeResult]) -> str: + """Render a human-readable terminal table of benchmark results. + + Only the columns in :data:`_TABLE_COLUMNS` are shown. Float values are + rounded to one decimal place. + + Args: + results: List of :class:`BenchServeResult` instances. + + Returns: + Formatted string using ``tabulate`` with ``tablefmt="simple"``. + """ + rows = [] + for r in results: + d = _result_to_dict(r) + row = [] + for col in _TABLE_COLUMNS: + val = d.get(col) + if isinstance(val, float): + val = round(val, 1) + row.append(val) + rows.append(row) + return _tabulate(rows, headers=_TABLE_COLUMNS, tablefmt="simple") + + +def format_json(results: list[BenchServeResult]) -> str: + """Serialize benchmark results as a JSON array. + + All fields from :data:`RESULT_COLUMNS` are included. + + Args: + results: List of :class:`BenchServeResult` instances. + + Returns: + JSON string with ``indent=2``. + """ + return json.dumps([_result_to_dict(r) for r in results], indent=2) + + +def format_csv(results: list[BenchServeResult]) -> str: + """Serialize benchmark results as CSV with a header row. + + All columns are included. + + Args: + results: List of :class:`BenchServeResult` instances. + + Returns: + CSV string (header + one row per result). + """ + buf = io.StringIO() + writer = csv_mod.DictWriter(buf, fieldnames=RESULT_COLUMNS) + writer.writeheader() + for r in results: + writer.writerow(_result_to_dict(r)) + return buf.getvalue() + + +def _sql_escape(value) -> str: + """Escape a Python value for use as a SQL literal. + + - ``None`` -> ``"NULL"`` + - ``bool`` -> ``"1"`` or ``"0"`` + - ``int`` / ``float`` -> string representation + - ``str`` -> single-quoted with internal single-quotes doubled + """ + if value is None: + return "NULL" + if isinstance(value, bool): + return "1" if value else "0" + if isinstance(value, float): + if math.isnan(value) or math.isinf(value): + return "NULL" + return str(value) + if isinstance(value, int): + return str(value) + # str + escaped = str(value).replace("'", "''") + return f"'{escaped}'" + + +_SQL_SCHEMA = ( + "run_id TEXT, timestamp TEXT, tag TEXT, " + "chip TEXT, gpu_cores INTEGER, memory_gb REAL, bandwidth_gbs REAL, os_version TEXT, " + "model_id TEXT, model_type TEXT, engine_type TEXT, mtp_enabled BOOLEAN, " + "specprefill BOOLEAN, kv_quant TEXT, cache_type TEXT, " + "prompt_set TEXT, concurrency INTEGER, max_tokens INTEGER, enable_thinking BOOLEAN, " + "extra_body TEXT, repetition INTEGER, prompt_tokens INTEGER, " + "ttft_ms REAL, tpot_ms REAL, e2e_latency_ms REAL, " + "gen_tps REAL, prompt_tps REAL, throughput_tps REAL, requests_per_s REAL, " + "metal_active_gb REAL, metal_peak_gb REAL, metal_cache_gb REAL, " + "cache_hits INTEGER, cache_misses INTEGER, cache_hit_rate REAL, tokens_saved INTEGER, " + "validated BOOLEAN" +) + + +def format_sql(results: list[BenchServeResult]) -> str: + """Emit a SQL ``CREATE TABLE IF NOT EXISTS`` statement and INSERT rows. + + The schema follows the exact column order defined in the bench-serve spec. + + Args: + results: List of :class:`BenchServeResult` instances. + + Returns: + SQL string containing the CREATE TABLE statement followed by one + INSERT statement per result. 
+ """ + lines = [ + f"CREATE TABLE IF NOT EXISTS bench_serve ({_SQL_SCHEMA});", + ] + for r in results: + d = _result_to_dict(r) + values = ", ".join(_sql_escape(d[col]) for col in RESULT_COLUMNS) + lines.append(f"INSERT INTO bench_serve VALUES ({values});") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# Task 7: Main async orchestrator +# --------------------------------------------------------------------------- + +logger = logging.getLogger(__name__) + + +async def run_bench_serve( + url: str = "http://127.0.0.1:8080", + model: Optional[str] = None, + prompt_sets: list[str] = None, + prompt_file: Optional[str] = None, + concurrencies: list[int] = None, + max_tokens: int = 256, + repetitions: int = 3, + warmup: int = 1, + thinking_values: list[Optional[bool]] = None, + extra_bodies: list[str] = None, + output_path: Optional[str] = None, + fmt: str = "table", + do_validate: bool = True, + scrape: bool = True, + tag: Optional[str] = None, + override_fields: Optional[dict] = None, + system_prompt_file: Optional[str] = None, + skip_preflight_token_count: bool = False, +) -> list[BenchServeResult]: + """Run the full bench-serve sweep against a running vllm-mlx server. + + Args: + url: Base URL of the server. + model: Model ID to use. If ``None``, auto-detected from the server. + prompt_sets: List of prompt set names or paths. Defaults to + ``["short", "medium", "long"]``. + prompt_file: Optional path to an extra prompt file to include. + concurrencies: Concurrency levels to sweep. Defaults to ``[1, 4]``. + max_tokens: Maximum tokens to generate per request. + repetitions: Number of repetitions per sweep config. + warmup: Number of warmup rounds before the first measured repetition. + thinking_values: Values for ``enable_thinking``. Defaults to + ``[None]``. + extra_bodies: JSON strings for extra body parameters. Defaults to + ``[""]`` (no extra body). + output_path: File path to write results to. If ``None``, prints to + stdout. + fmt: Output format — one of ``"table"``, ``"json"``, ``"csv"``, + ``"sql"``. + do_validate: Whether to validate each response. + scrape: Whether to scrape ``/metrics`` before and after each run. + tag: Optional tag string stored in every result row. + override_fields: Dict of field names to override on every result. + + Returns: + List of :class:`BenchServeResult` instances. + """ + # 1. Set defaults + if prompt_sets is None: + prompt_sets = ["short", "medium", "long"] + if concurrencies is None: + concurrencies = [1, 4] + if thinking_values is None: + thinking_values = [None] + if extra_bodies is None: + extra_bodies = [""] + if override_fields is None: + override_fields = {} + + # 2. Generate run_id and timestamp + run_id = str(uuid.uuid4())[:8] + timestamp = datetime.now(timezone.utc).isoformat() + + # 3. Open HTTP client + async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client: + # 4. Auto-detect runtime and hardware + print(f"Connecting to {url}...") + runtime = await auto_detect_runtime(client, url) + hw = detect_hardware_fingerprint() + + # 5. Resolve model_id + model_id = model or runtime.get("model_id", "") + if not model_id: + print( + "Error: could not determine model ID. Use --model to specify.", + file=sys.stderr, + ) + return [] + + # 6. 
Print hardware and runtime info + print( + f"Hardware: {hw.get('chip', 'unknown')} / {hw.get('memory_gb', 0):.0f}GB / {hw.get('os_version', '')}" + ) + print( + f"Runtime: model={model_id} engine={runtime.get('engine_type', '')} cache={runtime.get('cache_type', '')}" + ) + + # 7. Load prompts + all_prompts: dict[str, list[list[dict]]] = {} + for ps in prompt_sets: + try: + all_prompts[ps] = load_prompt_set(ps) + except FileNotFoundError as exc: + print(f"Warning: skipping prompt set '{ps}': {exc}", file=sys.stderr) + if prompt_file: + try: + all_prompts[prompt_file] = load_prompt_set(prompt_file) + except FileNotFoundError as exc: + print( + f"Warning: skipping prompt file '{prompt_file}': {exc}", + file=sys.stderr, + ) + + if not all_prompts: + print("Error: no prompt sets could be loaded.", file=sys.stderr) + return [] + + # 7b. If --system-prompt-file given, prepend that system message to + # every prompt across every set. This is the warm-prompts path: the + # server was started with the same system in its warm-up file, so + # every request here hits the prefix cache. + if system_prompt_file: + sys_path = Path(system_prompt_file).expanduser() + if not sys_path.exists(): + print( + f"Error: --system-prompt-file not found: {sys_path}", + file=sys.stderr, + ) + return [] + sys_content = sys_path.read_text() + system_msg = {"role": "system", "content": sys_content} + for ps, prompts in all_prompts.items(): + patched: list[list[dict]] = [] + for msgs in prompts: + # Do not double-prepend if the prompt already has a system + # as its first message. + if msgs and msgs[0].get("role") == "system": + patched.append(msgs) + else: + patched.append([system_msg] + msgs) + all_prompts[ps] = patched + print(f"System prompt prepended from {sys_path} ({len(sys_content)} chars)") + + # 8. Token-count first prompt from each set. + # This sends a non-streaming max_tokens=1 request per prompt set, which + # populates the server's prefix cache with the full prompt. That is + # harmless for ordinary benchmarking but DEFEATS cold-vs-warm + # comparisons (both paths end up warm after the pre-flight). Skip it + # when --skip-preflight-token-count is set; prompt_tokens will be + # populated from the first measured request's usage instead. + prompt_token_counts: dict[str, int] = {} + if skip_preflight_token_count: + for ps in all_prompts: + prompt_token_counts[ps] = 0 + else: + for ps, prompts in all_prompts.items(): + try: + count = await count_prompt_tokens(client, url, prompts[0], model_id) + prompt_token_counts[ps] = count + except Exception: + prompt_token_counts[ps] = 0 + + # 9. Expand sweep + sweep = expand_sweep( + list(all_prompts.keys()), + concurrencies, + thinking_values, + extra_bodies, + repetitions, + ) + + # Account for warmup rounds: insert warmup configs at rep==0 boundaries. + # We handle warmup inline during the sweep by tracking which + # (ps, conc, think, eb) combos have been warmed up. + total_runs = len(sweep) + warmup_note = f" (+ {warmup} warmup per config)" if warmup > 0 else "" + print(f"Total runs: {total_runs}{warmup_note}") + + results: list[BenchServeResult] = [] + warmed_up: set[tuple] = set() + + # 11. 
Iterate over sweep + for ps, conc, think, eb, rep in sweep: + prompts = all_prompts[ps] + + # Parse extra_body + extra_body_dict: Optional[dict] = None + if eb: + try: + extra_body_dict = json.loads(eb) + except json.JSONDecodeError: + extra_body_dict = None + + # Format progress label + think_label = f"think={think}" if think is not None else "" + eb_label = f"eb={eb[:20]}" if eb else "" + label_parts = [ps, f"conc={conc}", f"rep={rep}"] + if think_label: + label_parts.append(think_label) + if eb_label: + label_parts.append(eb_label) + label = " ".join(label_parts) + + # d. Warmup on first rep of each unique config + warmup_key = (ps, conc, think, eb) + if rep == 0 and warmup_key not in warmed_up and warmup > 0: + warmed_up.add(warmup_key) + for _ in range(warmup): + try: + await run_concurrent_requests( + client=client, + base_url=url, + prompts=prompts, + model=model_id, + concurrency=conc, + max_tokens=max_tokens, + enable_thinking=think, + extra_body=extra_body_dict, + do_validate=False, + ) + except Exception: + pass + + # e. Scrape /metrics before + metrics_before: dict = {} + if scrape: + metrics_before = await scrape_metrics(client, url) + + # f. Run concurrent requests + req_results = await run_concurrent_requests( + client=client, + base_url=url, + prompts=prompts, + model=model_id, + concurrency=conc, + max_tokens=max_tokens, + enable_thinking=think, + extra_body=extra_body_dict, + do_validate=do_validate, + ) + + # g. Scrape /metrics after, compute cache delta + metrics_after: dict = {} + if scrape: + metrics_after = await scrape_metrics(client, url) + + cache_hits_delta = metrics_after.get("cache_hits", 0) - metrics_before.get( + "cache_hits", 0 + ) + cache_misses_delta = metrics_after.get( + "cache_misses", 0 + ) - metrics_before.get("cache_misses", 0) + tokens_saved_delta = metrics_after.get( + "tokens_saved", 0 + ) - metrics_before.get("tokens_saved", 0) + total_events = cache_hits_delta + cache_misses_delta + cache_hit_rate = ( + cache_hits_delta / total_events if total_events > 0 else 0.0 + ) + + # h. Get metal memory from /v1/status + metal_active_gb = runtime.get("metal_active_gb", 0.0) + metal_peak_gb = runtime.get("metal_peak_gb", 0.0) + metal_cache_gb = runtime.get("metal_cache_gb", 0.0) + try: + resp = await client.get(f"{url}/v1/status") + resp.raise_for_status() + status_data = parse_status_response(resp.json()) + metal_active_gb = status_data.get("metal_active_gb", metal_active_gb) + metal_peak_gb = status_data.get("metal_peak_gb", metal_peak_gb) + metal_cache_gb = status_data.get("metal_cache_gb", metal_cache_gb) + except Exception: + pass + + # i. 
Aggregate per-request metrics + valid_results = [r for r in req_results if "error" not in r] + if not valid_results: + # All requests errored — build a failed result + result_obj = BenchServeResult( + run_id=run_id, + timestamp=timestamp, + tag=tag or "", + # Hardware + chip=hw.get("chip", ""), + gpu_cores=hw.get("gpu_cores", 0), + memory_gb=hw.get("memory_gb", 0.0), + bandwidth_gbs=hw.get("bandwidth_gbs", 0.0), + os_version=hw.get("os_version", ""), + # Runtime + model_id=model_id, + model_type=runtime.get("model_type", ""), + engine_type=runtime.get("engine_type", ""), + mtp_enabled=runtime.get("mtp_enabled", False), + specprefill=runtime.get("specprefill", False), + kv_quant=runtime.get("kv_quant", ""), + cache_type=runtime.get("cache_type", ""), + # Config + prompt_set=ps, + concurrency=conc, + max_tokens=max_tokens, + enable_thinking=think, + extra_body=eb, + repetition=rep, + prompt_tokens=prompt_token_counts.get(ps, 0), + # Latency / throughput all zero + validated=False, + ) + print(f" {label}: FAIL (all requests errored)") + else: + + def _mean(key: str) -> float: + vals = [ + r[key] for r in valid_results if key in r and r[key] is not None + ] + return statistics.mean(vals) if vals else 0.0 + + mean_ttft = _mean("ttft_ms") + mean_tpot = _mean("tpot_ms") + mean_gen_tps = _mean("gen_tps") + mean_prompt_tps = _mean("prompt_tps") + mean_e2e = _mean("e2e_latency_ms") + + total_completion_tokens = sum( + r.get("completion_tokens", 0) for r in valid_results + ) + max_e2e_seconds = ( + max( + (r.get("e2e_latency_ms", 0.0) for r in valid_results), + default=0.0, + ) + / 1000.0 + ) + throughput_tps = ( + total_completion_tokens / max_e2e_seconds + if max_e2e_seconds > 0 + else 0.0 + ) + requests_per_s = conc / max_e2e_seconds if max_e2e_seconds > 0 else 0.0 + + all_validated = all(r.get("validated", True) for r in valid_results) + + result_obj = BenchServeResult( + run_id=run_id, + timestamp=timestamp, + tag=tag or "", + # Hardware + chip=hw.get("chip", ""), + gpu_cores=hw.get("gpu_cores", 0), + memory_gb=hw.get("memory_gb", 0.0), + bandwidth_gbs=hw.get("bandwidth_gbs", 0.0), + os_version=hw.get("os_version", ""), + # Runtime + model_id=model_id, + model_type=runtime.get("model_type", ""), + engine_type=runtime.get("engine_type", ""), + mtp_enabled=runtime.get("mtp_enabled", False), + specprefill=runtime.get("specprefill", False), + kv_quant=runtime.get("kv_quant", ""), + cache_type=runtime.get("cache_type", ""), + # Config + prompt_set=ps, + concurrency=conc, + max_tokens=max_tokens, + enable_thinking=think, + extra_body=eb, + repetition=rep, + prompt_tokens=prompt_token_counts.get(ps, 0), + # Latency + ttft_ms=mean_ttft, + tpot_ms=mean_tpot, + e2e_latency_ms=mean_e2e, + # Throughput + gen_tps=mean_gen_tps, + prompt_tps=mean_prompt_tps, + throughput_tps=throughput_tps, + requests_per_s=requests_per_s, + # Memory + metal_active_gb=metal_active_gb, + metal_peak_gb=metal_peak_gb, + metal_cache_gb=metal_cache_gb, + # Cache + cache_hits=cache_hits_delta, + cache_misses=cache_misses_delta, + cache_hit_rate=cache_hit_rate, + tokens_saved=tokens_saved_delta, + # Validation + validated=all_validated, + ) + + status = "PASS" if all_validated else "FAIL" + print( + f" {label}: TTFT={mean_ttft:.0f}ms TPS={mean_gen_tps:.1f} {status}" + ) + + # j. Apply override_fields + for field_name, field_val in override_fields.items(): + if hasattr(result_obj, field_name): + setattr(result_obj, field_name, field_val) + + results.append(result_obj) + + # 12. 
Format output + formatters = { + "table": format_table, + "json": format_json, + "csv": format_csv, + "sql": format_sql, + } + formatter = formatters.get(fmt, format_table) + output = formatter(results) + + # 13. Write to file or stdout + if output_path: + Path(output_path).write_text(output) + print(f"\nResults written to {output_path}") + else: + print() + print(output) + + return results diff --git a/vllm_mlx/bench_serve_prompts/long.json b/vllm_mlx/bench_serve_prompts/long.json new file mode 100644 index 000000000..f912a107b --- /dev/null +++ b/vllm_mlx/bench_serve_prompts/long.json @@ -0,0 +1,14 @@ +[ + { + "role": "user", + "content": "You are a senior site reliability engineer. Analyze the following server logs from a production incident that occurred between 03:42 and 04:15 UTC. Identify the root cause, the cascade failure sequence, and write a detailed post-mortem with timeline, impact assessment, and remediation plan.\n\nServer logs:\n```\n2026-04-17T03:42:01.123Z [INFO] web-01 nginx: upstream response time 0.142s for /api/v2/inference\n2026-04-17T03:42:03.445Z [INFO] web-01 nginx: upstream response time 0.189s for /api/v2/inference\n2026-04-17T03:42:11.002Z [WARN] gpu-01 vllm: KV cache utilization at 78% (threshold: 80%)\n2026-04-17T03:42:15.334Z [INFO] web-02 nginx: upstream response time 0.201s for /api/v2/inference\n2026-04-17T03:42:22.891Z [WARN] gpu-01 vllm: KV cache utilization at 82% — preemption starting\n2026-04-17T03:42:23.001Z [INFO] gpu-01 vllm: preempting request req-4829 (priority: low, tokens: 12847)\n2026-04-17T03:42:23.445Z [INFO] gpu-01 vllm: preempting request req-4831 (priority: low, tokens: 9234)\n2026-04-17T03:42:25.112Z [WARN] gpu-02 vllm: KV cache utilization at 79%\n2026-04-17T03:42:30.002Z [ERROR] lb-01 haproxy: backend gpu-01 health check failed (timeout 5000ms)\n2026-04-17T03:42:30.223Z [WARN] lb-01 haproxy: removing gpu-01 from pool (active connections: 47)\n2026-04-17T03:42:30.889Z [INFO] lb-01 haproxy: redistributing 47 connections to gpu-02, gpu-03\n2026-04-17T03:42:31.002Z [WARN] gpu-02 vllm: KV cache utilization at 91% — preemption starting\n2026-04-17T03:42:31.445Z [WARN] gpu-03 vllm: KV cache utilization at 85% — preemption starting\n2026-04-17T03:42:32.001Z [ERROR] gpu-02 vllm: OOM in KV cache — dropping batch of 23 requests\n2026-04-17T03:42:32.334Z [ERROR] gpu-03 vllm: OOM in KV cache — dropping batch of 18 requests\n2026-04-17T03:42:33.001Z [ERROR] lb-01 haproxy: backend gpu-02 health check failed\n2026-04-17T03:42:33.223Z [ERROR] lb-01 haproxy: backend gpu-03 health check failed\n2026-04-17T03:42:33.445Z [CRITICAL] lb-01 haproxy: NO HEALTHY BACKENDS — service unavailable\n2026-04-17T03:42:33.890Z [ERROR] web-01 nginx: 502 Bad Gateway for 100% of requests\n2026-04-17T03:42:34.112Z [ERROR] web-02 nginx: 502 Bad Gateway for 100% of requests\n2026-04-17T03:42:34.334Z [ALERT] pagerduty: P1 incident created — inference service down (incident-7821)\n2026-04-17T03:43:01.001Z [INFO] gpu-01 vllm: health check passing — KV cache at 45% after preemption\n2026-04-17T03:43:01.223Z [INFO] lb-01 haproxy: adding gpu-01 back to pool\n2026-04-17T03:43:02.001Z [INFO] gpu-02 vllm: restarting after OOM crash\n2026-04-17T03:43:02.334Z [INFO] gpu-03 vllm: restarting after OOM crash\n2026-04-17T03:43:15.001Z [INFO] lb-01 haproxy: gpu-01 serving 100% of traffic (load: 94%)\n2026-04-17T03:43:22.001Z [WARN] gpu-01 vllm: KV cache utilization at 83% — preemption starting\n2026-04-17T03:43:23.001Z [ERROR] lb-01 haproxy: backend gpu-01 health check failed 
again\n2026-04-17T03:43:23.334Z [CRITICAL] lb-01 haproxy: NO HEALTHY BACKENDS — service unavailable (second outage)\n2026-04-17T03:44:01.001Z [INFO] on-call-eng: manual intervention — enabling request throttling at lb-01\n2026-04-17T03:44:05.001Z [INFO] lb-01 haproxy: rate limiting to 200 req/s (down from 1200 req/s)\n2026-04-17T03:44:10.001Z [INFO] gpu-02 vllm: model loaded, joining pool\n2026-04-17T03:44:12.001Z [INFO] gpu-03 vllm: model loaded, joining pool\n2026-04-17T03:44:15.001Z [INFO] lb-01 haproxy: 3 healthy backends, distributing load\n2026-04-17T03:44:20.001Z [INFO] gpu-01 vllm: KV cache at 61%, stable\n2026-04-17T03:44:20.334Z [INFO] gpu-02 vllm: KV cache at 58%, stable\n2026-04-17T03:44:20.445Z [INFO] gpu-03 vllm: KV cache at 55%, stable\n2026-04-17T03:44:45.001Z [INFO] lb-01 haproxy: gradually increasing rate limit to 400 req/s\n2026-04-17T03:45:30.001Z [INFO] lb-01 haproxy: rate limit at 600 req/s, all backends stable\n2026-04-17T03:47:00.001Z [INFO] lb-01 haproxy: rate limit at 900 req/s, KV cache stable at 70-72%\n2026-04-17T03:50:00.001Z [INFO] lb-01 haproxy: rate limit removed, operating normally at 1150 req/s\n2026-04-17T03:50:05.001Z [INFO] pagerduty: incident-7821 resolved after 7m 31s\n2026-04-17T04:00:00.001Z [INFO] monitoring: error rate back to baseline (0.02%)\n```\n\nAdditional context:\n- Peak traffic started at 03:40 UTC (promotional campaign began)\n- Each GPU node has 80GB VRAM, models load at ~45GB\n- Health check timeout is 5000ms, configured conservatively to avoid flapping\n- The haproxy health check calls /health endpoint which blocks on vllm internal locks during preemption\n- Normal operating throughput: 400-600 req/s per GPU node\n- The load balancer uses round-robin with no connection-count awareness\n- Model restart time: approximately 70 seconds to reload weights\n- There are 4 GPU nodes total but gpu-04 was offline for scheduled maintenance\n\nWrite:\n1. Executive summary (3-4 sentences, non-technical audience)\n2. Detailed incident timeline with annotations explaining what each event means technically\n3. Root cause analysis — identify the PRIMARY root cause and up to 3 contributing factors\n4. Cascade failure analysis — explain the mechanism by which gpu-01 failure caused total outage\n5. Five-whys analysis for the primary root cause\n6. Impact assessment: estimated requests dropped, duration, affected users\n7. Immediate remediation actions taken and their effectiveness\n8. 10-item action plan to prevent recurrence, categorized by: (a) immediate/this week, (b) short-term/this quarter, (c) long-term/architectural\n9. Metrics and alerting gaps identified from the logs\n10. What a graceful degradation response would have looked like instead" + }, + { + "role": "user", + "content": "You are reviewing a distributed system for a real-time collaborative text editor (similar to Google Docs). The system uses Operational Transformation (OT) for conflict resolution. Below is the complete server-side implementation of the OT engine and WebSocket handler. 
Perform a thorough architectural review, identify correctness bugs, performance bottlenecks, and scalability limits.\n\n```python\n# ot_engine.py — Operational Transformation Engine\nfrom dataclasses import dataclass, field\nfrom typing import List, Optional, Dict, Any, Tuple\nimport time\nimport asyncio\nimport json\nimport hashlib\nfrom collections import defaultdict\n\n@dataclass\nclass Operation:\n \"\"\"A single OT operation.\"\"\"\n op_type: str # 'insert', 'delete', 'retain'\n position: int\n content: str = ''\n length: int = 0\n client_id: str = ''\n revision: int = 0\n timestamp: float = field(default_factory=time.time)\n\n@dataclass\nclass Document:\n \"\"\"Server-side document state.\"\"\"\n doc_id: str\n content: str = ''\n revision: int = 0\n history: List[Operation] = field(default_factory=list)\n clients: Dict[str, Any] = field(default_factory=dict)\n lock: asyncio.Lock = field(default_factory=asyncio.Lock)\n\nclass OTEngine:\n def __init__(self):\n self.documents: Dict[str, Document] = {}\n \n def get_or_create_document(self, doc_id: str) -> Document:\n if doc_id not in self.documents:\n self.documents[doc_id] = Document(doc_id=doc_id)\n return self.documents[doc_id]\n \n def apply_operation(self, doc: Document, op: Operation) -> str:\n \"\"\"Apply operation to document content.\"\"\"\n content = doc.content\n if op.op_type == 'insert':\n new_content = content[:op.position] + op.content + content[op.position:]\n elif op.op_type == 'delete':\n new_content = content[:op.position] + content[op.position + op.length:]\n elif op.op_type == 'retain':\n new_content = content\n else:\n raise ValueError(f'Unknown op type: {op.op_type}')\n return new_content\n \n def transform(self, op1: Operation, op2: Operation) -> Tuple[Operation, Operation]:\n \"\"\"Transform op1 against op2, return transformed versions.\"\"\"\n op1_prime = Operation(**vars(op1))\n op2_prime = Operation(**vars(op2))\n \n if op1.op_type == 'insert' and op2.op_type == 'insert':\n if op1.position <= op2.position:\n op2_prime.position += len(op1.content)\n else:\n op1_prime.position += len(op2.content)\n \n elif op1.op_type == 'insert' and op2.op_type == 'delete':\n if op1.position <= op2.position:\n op2_prime.position += len(op1.content)\n elif op1.position > op2.position + op2.length:\n op1_prime.position -= op2.length\n else:\n op1_prime.position = op2.position\n \n elif op1.op_type == 'delete' and op2.op_type == 'insert':\n if op2.position <= op1.position:\n op1_prime.position += len(op2.content)\n elif op2.position > op1.position + op1.length:\n op2_prime.position -= op1.length\n else:\n op2_prime.position = op1.position\n \n elif op1.op_type == 'delete' and op2.op_type == 'delete':\n if op1.position + op1.length <= op2.position:\n op2_prime.position -= op1.length\n elif op2.position + op2.length <= op1.position:\n op1_prime.position -= op2.length\n else:\n # Overlapping deletes - handle overlap\n overlap_start = max(op1.position, op2.position)\n overlap_end = min(op1.position + op1.length, op2.position + op2.length)\n overlap = overlap_end - overlap_start\n op1_prime.length = max(0, op1.length - overlap)\n op2_prime.length = max(0, op2.length - overlap)\n \n return op1_prime, op2_prime\n \n async def submit_operation(self, doc_id: str, op: Operation) -> Dict:\n \"\"\"Submit an operation from a client.\"\"\"\n doc = self.get_or_create_document(doc_id)\n \n async with doc.lock:\n # Transform against all operations since client's revision\n ops_since = doc.history[op.revision:]\n \n current_op = op\n for hist_op in 
ops_since:\n current_op, _ = self.transform(current_op, hist_op)\n \n # Apply the transformed operation\n doc.content = self.apply_operation(doc, current_op)\n doc.revision += 1\n current_op.revision = doc.revision\n doc.history.append(current_op)\n \n return {\n 'success': True,\n 'revision': doc.revision,\n 'operation': vars(current_op),\n 'checksum': hashlib.md5(doc.content.encode()).hexdigest()\n }\n\n# websocket_handler.py\nimport websockets\nimport asyncio\nfrom typing import Set\n\nengine = OTEngine()\nconnected_clients: Dict[str, Set] = defaultdict(set)\n\nasync def handle_client(websocket, path):\n doc_id = path.strip('/')\n doc = engine.get_or_create_document(doc_id)\n connected_clients[doc_id].add(websocket)\n doc.clients[websocket] = {'joined_at': time.time()}\n \n try:\n # Send current document state\n await websocket.send(json.dumps({\n 'type': 'init',\n 'content': doc.content,\n 'revision': doc.revision\n }))\n \n async for message in websocket:\n data = json.loads(message)\n \n if data['type'] == 'operation':\n op = Operation(\n op_type=data['op_type'],\n position=data['position'],\n content=data.get('content', ''),\n length=data.get('length', 0),\n client_id=data['client_id'],\n revision=data['revision']\n )\n \n result = await engine.submit_operation(doc_id, op)\n \n # Broadcast to all clients\n broadcast_msg = json.dumps({\n 'type': 'operation',\n 'operation': result['operation'],\n 'revision': result['revision'],\n 'checksum': result['checksum']\n })\n \n disconnected = set()\n for client in connected_clients[doc_id]:\n try:\n await client.send(broadcast_msg)\n except websockets.exceptions.ConnectionClosed:\n disconnected.add(client)\n \n for client in disconnected:\n connected_clients[doc_id].discard(client)\n \n except websockets.exceptions.ConnectionClosed:\n pass\n finally:\n connected_clients[doc_id].discard(websocket)\n del doc.clients[websocket]\n\nstart_server = websockets.serve(handle_client, 'localhost', 8765)\n\nasio.get_event_loop().run_until_complete(start_server)\nasyncio.get_event_loop().run_forever()\n```\n\nProvide:\n1. Correctness analysis: identify at least 5 OT algorithm bugs that would cause document divergence\n2. Concurrency analysis: identify race conditions and lock-ordering problems\n3. Memory and scalability analysis: what breaks at 1000 concurrent documents? 100k?\n4. Security analysis: what can a malicious client do?\n5. Performance bottlenecks: where is the O(n) or O(n²) complexity hiding?\n6. Missing features that would be required for production (list at least 8)\n7. Alternative approaches: when would you choose CRDT (e.g., Yjs, Automerge) over OT, and vice versa?\n8. A corrected implementation of the transform() function for the delete-delete case\n9. Architecture diagram description for a production-grade version supporting 100k simultaneous editors across 5 regions\n10. Estimated engineering effort to make this production-ready, broken down by component" + }, + { + "role": "user", + "content": "You are a technical architect. Write a comprehensive technical specification for a new feature: real-time collaborative AI prompt engineering workspace. 
This is a web application where multiple users can simultaneously edit, version, test, and optimize LLM prompts together.\n\nFeature requirements gathered from user research:\n- Up to 50 simultaneous collaborators per workspace\n- Real-time collaborative editing of prompt templates with variable substitution syntax like {{variable_name}}\n- Version history with branching (like git branches, not just linear history)\n- Side-by-side A/B testing of two prompt variants against the same test cases\n- Automated evaluation: run prompts against a test suite and score outputs using LLM-as-judge\n- Prompt chaining: connect multiple prompts where output of one feeds into another\n- Model comparison: run the same prompt against multiple LLM providers (OpenAI, Anthropic, local models)\n- Role-based access: Owner, Editor, Commenter, Viewer\n- Real-time presence indicators (who is editing what)\n- Comment threads anchored to specific positions in the prompt text\n- Export to Python/TypeScript SDK code\n- Webhook notifications when evaluation scores drop below threshold\n- Prompt marketplace: share prompts publicly with usage analytics\n\nDeliverable: Write the technical specification including:\n\n1. **System overview** — high-level architecture with component descriptions\n2. **Data model** — complete entity definitions with fields, types, relationships, and constraints for: Workspace, PromptTemplate, PromptVersion, PromptBranch, Variable, TestCase, EvaluationRun, EvaluationResult, ModelProvider, PromptChain, Comment, User, Permission\n3. **API design** — RESTful endpoints for all CRUD operations plus WebSocket events for real-time collaboration. Include request/response shapes for the 5 most complex endpoints.\n4. **Collaborative editing protocol** — how simultaneous edits are handled (OT vs CRDT decision with justification), conflict resolution, presence protocol\n5. **Evaluation pipeline** — architecture for running async evaluations, LLM-as-judge implementation, scoring rubrics, comparison statistics\n6. **Branching model** — how prompt version branching works (data model, merge semantics, conflict detection)\n7. **Prompt chaining execution** — DAG representation, execution order, variable passing, error handling, circular dependency detection\n8. **Performance requirements** — latency targets, throughput requirements, caching strategy\n9. **Security model** — authentication, authorization matrix, prompt injection risks in the evaluation pipeline\n10. **Implementation roadmap** — phased delivery plan across 4 quarters with team size recommendations and dependency ordering\n11. **Open questions and risks** — at least 8 unresolved design decisions with trade-off analysis\n12. **Competitive analysis** — compare to PromptLayer, Langsmith, Promptfoo, and Helicone; identify where this spec is differentiated" + } +] diff --git a/vllm_mlx/bench_serve_prompts/medium.json b/vllm_mlx/bench_serve_prompts/medium.json new file mode 100644 index 000000000..f93a82dae --- /dev/null +++ b/vllm_mlx/bench_serve_prompts/medium.json @@ -0,0 +1,22 @@ +[ + { + "role": "user", + "content": "Please review the following Python service code and provide a detailed code review. Identify any bugs, performance issues, security concerns, and style problems. 
Suggest concrete improvements with code examples where appropriate.\n\n```python\nimport sqlite3\nimport hashlib\nimport time\nfrom flask import Flask, request, jsonify\n\napp = Flask(__name__)\n\nDB_PATH = '/tmp/users.db'\n\ndef get_db():\n conn = sqlite3.connect(DB_PATH)\n return conn\n\ndef init_db():\n conn = get_db()\n conn.execute('''\n CREATE TABLE IF NOT EXISTS users (\n id INTEGER PRIMARY KEY,\n username TEXT,\n password TEXT,\n email TEXT,\n created_at INTEGER\n )\n ''')\n conn.commit()\n\n@app.route('/login', methods=['POST'])\ndef login():\n data = request.json\n username = data['username']\n password = data['password']\n \n conn = get_db()\n query = f\"SELECT * FROM users WHERE username='{username}' AND password='{password}'\"\n result = conn.execute(query).fetchone()\n \n if result:\n return jsonify({'status': 'ok', 'user_id': result[0]})\n else:\n return jsonify({'status': 'error', 'message': 'Invalid credentials'})\n\n@app.route('/register', methods=['POST'])\ndef register():\n data = request.json\n username = data['username']\n password = hashlib.md5(data['password'].encode()).hexdigest()\n email = data['email']\n \n conn = get_db()\n conn.execute(f\"INSERT INTO users VALUES (NULL, '{username}', '{password}', '{email}', {int(time.time())})\")\n conn.commit()\n return jsonify({'status': 'registered'})\n\n@app.route('/users', methods=['GET'])\ndef list_users():\n conn = get_db()\n users = conn.execute('SELECT id, username, email FROM users').fetchall()\n return jsonify({'users': [{'id': u[0], 'username': u[1], 'email': u[2]} for u in users]})\n\nif __name__ == '__main__':\n init_db()\n app.run(debug=True, host='0.0.0.0', port=5000)\n```\n\nAddress the following areas specifically:\n1. Security vulnerabilities (SQL injection, password hashing, authentication)\n2. Resource management (database connections, connection pooling)\n3. Error handling and input validation\n4. API design and HTTP status codes\n5. Production readiness concerns\n6. Python best practices and type annotations\n\nProvide a refactored version of the most critical endpoint (login) that addresses all identified issues." + }, + { + "role": "user", + "content": "Design a distributed rate limiter for a multi-region API gateway that serves 500,000 requests per second globally. The system must enforce per-user rate limits of 1,000 requests per minute with burst allowance of 200 requests, guarantee eventual consistency within 100ms across regions, and handle region failover gracefully.\n\nRequirements:\n- Regions: us-east-1, eu-west-1, ap-southeast-1\n- Per-user limit: 1,000 req/min with burst of 200\n- Global aggregate limit: 500,000 req/s\n- Consistency window: 100ms\n- Failover: region must operate independently if inter-region link drops\n- Storage: Redis Cluster preferred, but must handle Redis unavailability\n\nProvide:\n1. Architecture diagram description (components, data flow, communication patterns)\n2. Algorithm choice and justification (token bucket vs sliding window vs fixed window vs leaky bucket)\n3. Data model for rate limit state (key schema, TTL strategy, data types)\n4. Synchronization protocol between regions (gossip, pub/sub, or direct sync)\n5. Failover behavior and local fallback strategy\n6. Implementation sketch of the core rate-limit check function in Python or Go\n7. Approximate memory footprint calculation for 10 million unique users\n8. Monitoring and alerting strategy (which metrics, thresholds, and alert conditions)\n\nConsider trade-offs carefully and justify your choices. 
Identify any CAP theorem constraints that apply." + }, + { + "role": "user", + "content": "Analyze the following sorting algorithm implementations and compare their time complexity, space complexity, cache behavior, and practical performance characteristics. Then recommend which to use for each of the four scenarios described at the end.\n\n```python\ndef quicksort_lomuto(arr: list, low: int, high: int) -> None:\n if low < high:\n pivot_idx = partition_lomuto(arr, low, high)\n quicksort_lomuto(arr, low, pivot_idx - 1)\n quicksort_lomuto(arr, pivot_idx + 1, high)\n\ndef partition_lomuto(arr: list, low: int, high: int) -> int:\n pivot = arr[high]\n i = low - 1\n for j in range(low, high):\n if arr[j] <= pivot:\n i += 1\n arr[i], arr[j] = arr[j], arr[i]\n arr[i + 1], arr[high] = arr[high], arr[i + 1]\n return i + 1\n\ndef mergesort(arr: list) -> list:\n if len(arr) <= 1:\n return arr\n mid = len(arr) // 2\n left = mergesort(arr[:mid])\n right = mergesort(arr[mid:])\n return merge(left, right)\n\ndef merge(left: list, right: list) -> list:\n result = []\n i = j = 0\n while i < len(left) and j < len(right):\n if left[i] <= right[j]:\n result.append(left[i])\n i += 1\n else:\n result.append(right[j])\n j += 1\n result.extend(left[i:])\n result.extend(right[j:])\n return result\n\ndef timsort_simplified(arr: list, run: int = 32) -> list:\n n = len(arr)\n for i in range(0, n, run):\n insertion_sort(arr, i, min(i + run - 1, n - 1))\n size = run\n while size < n:\n for left in range(0, n, 2 * size):\n mid = min(left + size - 1, n - 1)\n right = min(left + 2 * size - 1, n - 1)\n if mid < right:\n merge_inplace(arr, left, mid, right)\n size *= 2\n return arr\n```\n\nScenarios:\n1. Sorting 10 million integers on a machine with 16GB RAM where the integers already have ~80% of elements in sorted order\n2. Sorting 100,000 user records by last name (strings, ~20 chars average) where stability is required\n3. Real-time sorting of a continuous stream of stock prices (need to insert one element at a time and maintain order)\n4. Sorting 500 small arrays of 20-50 elements each in a tight loop (inner loop of a graph algorithm)\n\nFor each scenario: state your recommendation, explain why, quantify the expected performance difference where possible, and note any caveats." + }, + { + "role": "user", + "content": "You are building a Kubernetes operator for managing ML model deployments. 
The operator must handle the full lifecycle: provisioning GPU nodes, loading model weights, health checking, auto-scaling based on queue depth, and graceful shutdown with request draining.\n\nHere is the current CRD definition and controller reconcile loop stub:\n\n```go\n// ModelDeployment CRD\ntype ModelDeployment struct {\n metav1.TypeMeta `json:\",inline\"`\n metav1.ObjectMeta `json:\"metadata,omitempty\"`\n Spec ModelDeploymentSpec `json:\"spec,omitempty\"`\n Status ModelDeploymentStatus `json:\"status,omitempty\"`\n}\n\ntype ModelDeploymentSpec struct {\n ModelID string `json:\"modelId\"`\n Replicas int32 `json:\"replicas\"`\n GPUType string `json:\"gpuType\"`\n MaxTokens int `json:\"maxTokens\"`\n ScalePolicy ScalePolicySpec `json:\"scalePolicy\"`\n Resources ResourceSpec `json:\"resources\"`\n}\n\ntype ScalePolicySpec struct {\n MinReplicas int32 `json:\"minReplicas\"`\n MaxReplicas int32 `json:\"maxReplicas\"`\n TargetQueueDepth int32 `json:\"targetQueueDepth\"`\n ScaleUpCooldown int `json:\"scaleUpCooldown\"`\n ScaleDownCooldown int `json:\"scaleDownCooldown\"`\n}\n\nfunc (r *ModelDeploymentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {\n var md mlv1.ModelDeployment\n if err := r.Get(ctx, req.NamespacedName, &md); err != nil {\n return ctrl.Result{}, client.IgnoreNotFound(err)\n }\n // TODO: implement reconcile logic\n return ctrl.Result{}, nil\n}\n```\n\nImplement the complete reconcile function including:\n1. State machine transitions (Pending -> Provisioning -> Loading -> Ready -> Scaling -> Draining -> Terminating)\n2. Idempotent node provisioning using node affinity and taints\n3. Model weight loading detection via readiness probe on /health/model endpoint\n4. HPA-equivalent logic reading from a custom queue-depth metric\n5. Graceful shutdown: stop accepting new requests, drain in-flight requests with 30s timeout, then terminate\n6. Proper use of finalizers for cleanup\n7. Status condition updates at each transition\n\nAlso identify three subtle bugs or design flaws in the existing CRD spec that would cause problems in production." + }, + { + "role": "user", + "content": "Conduct a detailed performance analysis of the following database query and its execution plan. 
Then rewrite the query and schema to achieve at least 10x improvement in query time for the described workload.\n\nSchema:\n```sql\nCREATE TABLE orders (\n id BIGINT PRIMARY KEY,\n user_id BIGINT NOT NULL,\n status VARCHAR(20) NOT NULL,\n created_at TIMESTAMP NOT NULL,\n updated_at TIMESTAMP NOT NULL,\n total_amount DECIMAL(10,2),\n shipping_address TEXT,\n metadata JSONB\n);\n\nCREATE TABLE order_items (\n id BIGINT PRIMARY KEY,\n order_id BIGINT REFERENCES orders(id),\n product_id BIGINT NOT NULL,\n quantity INT NOT NULL,\n unit_price DECIMAL(10,2) NOT NULL,\n discount DECIMAL(5,2) DEFAULT 0\n);\n\nCREATE TABLE products (\n id BIGINT PRIMARY KEY,\n name VARCHAR(200),\n category VARCHAR(50),\n supplier_id BIGINT,\n weight_kg DECIMAL(8,3)\n);\n```\n\nProblematic query (runs in ~8 seconds on 50M orders, 200M order_items, 2M products):\n```sql\nSELECT \n u.user_id,\n COUNT(DISTINCT o.id) as order_count,\n SUM(oi.quantity * oi.unit_price * (1 - oi.discount/100)) as total_revenue,\n AVG(oi.quantity * oi.unit_price * (1 - oi.discount/100)) as avg_order_value,\n STRING_AGG(DISTINCT p.category, ', ') as categories_purchased\nFROM orders o\nJOIN order_items oi ON o.id = oi.order_id\nJOIN products p ON oi.product_id = p.id\nJOIN (\n SELECT user_id FROM orders \n WHERE status = 'completed'\n GROUP BY user_id \n HAVING COUNT(*) > 5\n) u ON o.user_id = u.user_id\nWHERE o.created_at >= NOW() - INTERVAL '90 days'\n AND o.status IN ('completed', 'shipped')\nGROUP BY u.user_id\nHAVING SUM(oi.quantity * oi.unit_price) > 100\nORDER BY total_revenue DESC\nLIMIT 1000;\n```\n\nProvide:\n1. Analysis of what makes this query slow (access patterns, join order, missing indexes, expression evaluation)\n2. The EXPLAIN ANALYZE output you would expect to see and what each expensive node means\n3. Index strategy: which indexes to add, composite index column ordering rationale, partial index opportunities\n4. Query rewrite using CTEs, materialized views, or query restructuring\n5. Schema changes if warranted (partitioning strategy, denormalization opportunities)\n6. Estimated improvement per optimization and cumulative expected speedup" + } +] diff --git a/vllm_mlx/bench_serve_prompts/short.json b/vllm_mlx/bench_serve_prompts/short.json new file mode 100644 index 000000000..0f5e975f4 --- /dev/null +++ b/vllm_mlx/bench_serve_prompts/short.json @@ -0,0 +1,7 @@ +[ + {"role": "user", "content": "What are the three laws of thermodynamics? Explain each one briefly."}, + {"role": "user", "content": "Write a Python function that checks if a string is a palindrome. Include type hints."}, + {"role": "user", "content": "Compare and contrast TCP and UDP protocols. When would you use each?"}, + {"role": "user", "content": "Explain the difference between a stack and a queue data structure with examples."}, + {"role": "user", "content": "What causes tides on Earth? Describe the role of the Moon and Sun."} +] diff --git a/vllm_mlx/bench_serve_prompts/thinking.json b/vllm_mlx/bench_serve_prompts/thinking.json new file mode 100644 index 000000000..c92d8f9ad --- /dev/null +++ b/vllm_mlx/bench_serve_prompts/thinking.json @@ -0,0 +1,14 @@ +[ + { + "role": "user", + "content": "You have 12 balls that are identical in appearance. One ball is either heavier or lighter than the others — you do not know which. You have a balance scale (no weights) that you can use exactly 3 times. 
Each weighing tells you only whether the left side is heavier, the right side is heavier, or they are equal.\n\nDesign a strategy that is guaranteed to identify the odd ball AND determine whether it is heavier or lighter in exactly 3 weighings.\n\nRequirements:\n- Your strategy must work in the worst case — it cannot rely on getting lucky with early results.\n- Label the balls 1 through 12 so you can refer to them precisely.\n- Describe each weighing as a decision: what you put on each side depends on the results of previous weighings.\n- After describing the strategy, trace through 3 distinct example cases to show it working:\n (a) The odd ball is ball #7 and it is heavier\n (b) The odd ball is ball #3 and it is lighter\n (c) The odd ball is ball #12 and it is heavier\n- Explain why 3 weighings are sufficient for 12 balls (information-theoretic argument using log base 3).\n- Explain why 3 weighings would NOT be sufficient for 13 balls.\n\nThink step by step through the decision tree before committing to a strategy." + }, + { + "role": "user", + "content": "Solve Einstein's Riddle (also called the Zebra Puzzle) using systematic logical deduction. Show every deduction step explicitly — do not skip steps or jump to conclusions.\n\nThe puzzle:\nThere are 5 houses in a row, each painted a different color: Red, Green, White, Yellow, Blue.\nEach house is occupied by a person of a different nationality: English, Swedish, Danish, Norwegian, German.\nEach person drinks a different beverage: Tea, Coffee, Milk, Beer, Water.\nEach person smokes a different brand: Pall Mall, Dunhill, Blend, BlueMaster, Prince.\nEach person has a different pet: Dog, Cat, Bird, Horse, Fish.\n\nClues:\n1. The English person lives in the red house.\n2. The Swedish person has a dog.\n3. The Danish person drinks tea.\n4. The green house is directly to the left of the white house.\n5. The person in the green house drinks coffee.\n6. The person who smokes Pall Mall has a bird.\n7. The person in the yellow house smokes Dunhill.\n8. The person in the middle house (house 3) drinks milk.\n9. The Norwegian lives in the first house (house 1).\n10. The person who smokes Blend lives next to the person with the cat.\n11. The person with the horse lives next to the Dunhill smoker.\n12. The person who smokes BlueMaster drinks beer.\n13. The German smokes Prince.\n14. The Norwegian lives next to the blue house.\n15. The Blend smoker lives next to the person who drinks water.\n\nQuestion: Who has the fish?\n\nRules for your solution:\n- Number the houses 1 (leftmost) to 5 (rightmost).\n- Start from the most constrained clues and propagate constraints.\n- At each step, state which clue(s) you are using and what you can deduce.\n- Show the partial state of the grid after each significant deduction.\n- If you make an assumption that requires backtracking, say so explicitly.\n- Do not guess — every cell must be filled by logical necessity." + }, + { + "role": "user", + "content": "A farmer has a rectangular field. He knows the following facts:\n\n1. The perimeter of the field is 120 meters.\n2. The diagonal of the field is exactly 50 meters.\n3. He wants to divide the field into exactly 6 equal rectangular plots by making cuts parallel to the sides. He can make cuts parallel to the length, parallel to the width, or both.\n4. Each plot must have the same area.\n5. The cuts must go all the way from one side of the field to the opposite side (no partial cuts).\n\nPart A — Find the dimensions:\nFind the length and width of the field. 
Show your algebraic work step by step using the perimeter and diagonal constraints simultaneously. Verify your answer.\n\nPart B — Division strategies:\nFind ALL distinct ways to divide the field into exactly 6 equal rectangular plots using the constraint that cuts go edge-to-edge. For each strategy:\n- Describe the cut pattern (e.g., '2 cuts parallel to width, 1 cut parallel to length')\n- State the dimensions of each resulting plot\n- Calculate the perimeter of each individual plot\n- Determine which strategy maximizes the perimeter of each individual plot and which minimizes it\n\nPart C — Fencing cost:\nThe farmer wants to install internal fencing along all cuts. Fencing costs $45 per meter. For each division strategy:\n- Calculate the total length of internal fencing required\n- Calculate the total cost\n- Identify the cheapest strategy\n\nPart D — Constraint relaxation:\nIf the farmer could make cuts that go only partway across the field (but the 6 plots must still be equal rectangles and tile the full field without overlap or gaps), would any new division strategies become possible that weren't available before? Explain why or why not with a geometric argument.\n\nShow all arithmetic. Do not skip algebraic steps." + } +] diff --git a/vllm_mlx/bench_serve_prompts/warm_prompts_example.json b/vllm_mlx/bench_serve_prompts/warm_prompts_example.json new file mode 100644 index 000000000..bd288d837 --- /dev/null +++ b/vllm_mlx/bench_serve_prompts/warm_prompts_example.json @@ -0,0 +1,8 @@ +[ + [ + { + "role": "system", + "content": "You are a code assistant running inside a developer's terminal. Your purpose is to help with software engineering tasks: fixing bugs, writing tests, refactoring, reviewing diffs, explaining code, and designing small features.\n\n## General principles\n\n- Read files before proposing changes. Never guess at code behaviour.\n- Prefer small, focused edits over sweeping refactors.\n- Match the project's existing conventions: naming, indentation, import order, testing style.\n- Before writing production code, check whether an equivalent helper already exists.\n- Make sure non-trivial changes come with at least one test that would catch a regression.\n- Do not add error handling, fallbacks, or validation for scenarios that cannot happen. Trust internal callers; validate at trust boundaries only.\n- Default to writing no comments. Add one only when the why is non-obvious, when there is a hidden constraint, or when a future reader would be genuinely surprised.\n\n## Safety and security\n\n- Never generate or guess URLs unless you are confident they are legitimate.\n- Never commit or print secrets, private keys, or credentials.\n- Refuse destructive techniques, DoS, mass targeting, or detection evasion.\n- Dual-use tools (credential testing, exploit development) need clear authorised context: pentest engagements, CTF, research, defensive work.\n- Sanitize user-controlled input before interpolating into shell commands, SQL, HTML, regex, file paths.\n- Size-limit regular expressions and JSON schemas from user input to prevent ReDoS and resource exhaustion.\n\n## Executing actions\n\nConsider the reversibility and blast radius of every action. For local reversible actions you may proceed freely. 
For actions that are hard to reverse, affect shared systems, or could otherwise be risky, pause and confirm with the user first.\n\nExamples that require confirmation:\n- Deleting files, branches, database rows; dropping tables; `rm -rf`.\n- Force-pushing, `git reset --hard`, amending published commits, rebasing shared history.\n- Pushing code, creating or closing PRs, posting issues, sending Slack or email messages.\n- Uploading content to third-party services (pastebins, diagram renderers).\n- Modifying CI pipelines, shared infrastructure, permissions.\n\nWhen you encounter an obstacle, do not reach for destructive actions as a shortcut. Investigate root cause first. Treat unfamiliar files, branches, or configuration as possibly work-in-progress and ask before deleting.\n\n## Tools available\n\n- `Read` - read a file from the filesystem. Use before editing.\n- `Write` - create or fully overwrite a file.\n- `Edit` - apply a targeted string replacement in an existing file.\n- `Bash` - run a shell command. Prefer dedicated tools when one exists.\n- `Grep` - search file contents with ripgrep-compatible regex.\n- `Glob` - find files matching a pattern (e.g. `**/*.py`).\n- `Agent` - launch a specialised subagent for complex or parallel work.\n\n## Tool usage rules\n\n- Before editing a file you have not inspected, call `Read` first.\n- Prefer `Glob`, `Grep`, `Read` over `find`, `grep`, `cat` via `Bash`.\n- Do not `ls` to verify a file exists before `Read` - just `Read` it; the error is informative.\n- Use absolute file paths when the user gave one. Otherwise resolve from the current working directory.\n- For multi-step exploration that may need 3+ queries, dispatch an `Agent` subagent.\n- Run independent tool calls in parallel when they have no data dependency.\n\n## Output format\n\n- Responses are plain text rendered in a terminal. GitHub-flavoured markdown is supported.\n- Keep intra-session chatter minimal. State results, not process narration.\n- Use file_path:line_number when referring to specific code locations.\n- For diffs, show only the changed hunks with a line of surrounding context.\n- For commands, show the command, a short note on what it does, and the relevant part of the output.\n- Keep final responses under 100 words unless the task genuinely requires more.\n\n## Coding conventions\n\n- Python: follow PEP 8, use type hints on public functions, prefer f-strings, use `pathlib` over `os.path`, avoid mutable default arguments.\n- TypeScript: strict mode, prefer `type` over `interface` for data shapes, use `const` by default, async/await over `.then` chains.\n- Rust: prefer `?` over `unwrap`, use `thiserror` for library errors, `anyhow` for binaries, keep `unsafe` blocks small and commented.\n- SQL: write upper-case keywords, qualify column names in joins, prefer CTEs for readability, always specify columns in INSERT.\n- Shell: use `set -euo pipefail`, quote variable expansions, prefer `[[ ]]` over `[ ]`, avoid parsing `ls`.\n\n## Testing conventions\n\n- Write a failing test before the fix when debugging.\n- Each test covers one behaviour. 
Avoid stuffing unrelated assertions into one test.\n- Prefer integration tests for glue, unit tests for algorithms, fast smoke tests for wiring.\n- Do not test implementation details that make refactoring painful.\n- Use the project's existing fixtures and factories; do not invent parallel infrastructure.\n- Deterministic tests only: seed RNGs, freeze clocks, fake external IO.\n\n## Git workflow\n\n- Commit messages: short, informal, explain why when it is non-obvious.\n- One logical change per commit. Unrelated cleanup goes in a separate commit.\n- Never commit secrets, large binaries, or machine-specific paths.\n- Rebase rather than merge for feature branches off a fast-moving main.\n- Show diffs and commit messages to the user for approval before pushing.\n\n## Mistakes to avoid\n\n- Adding features, refactors, or abstractions beyond what the task requires.\n- Writing docstrings that restate the function signature.\n- Silencing failing tests by weakening assertions instead of fixing the underlying issue.\n- Introducing new dependencies without flagging the trade-off.\n- Catching `Exception` and swallowing it.\n- Returning `None` or `-1` as a silent error signal in new code.\n- Putting unit-test-only logic inside production modules.\n- Copy-pasting code instead of extracting a shared helper when 3+ copies exist.\n\n## How to collaborate\n\nThe user is an experienced engineer. They want accurate, concise, actionable help. They prefer:\n\n- Short messages with specific file paths and line numbers.\n- Direct recommendations, not a menu of options, unless a real trade-off exists.\n- Honest uncertainty over confident guessing.\n- Admission when you do not know, followed by a suggestion for how to find out.\n\nWhen a task is ambiguous, ask one sharp clarifying question rather than assuming. When the task is clear, execute. When you have reached the edge of your context, say so explicitly.\n\nYou are good at this. Keep momentum." + } + ] +] diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 36edd91d8..5ec677a27 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -15,6 +15,8 @@ import argparse import sys +from .cli_arg_types import make_json_object_arg_parser + def serve_command(args): """Start the OpenAI-compatible server.""" @@ -26,7 +28,6 @@ def serve_command(args): # Import unified server from . 
import server - from .scheduler import SchedulerConfig from .server import RateLimiter, app, load_model logger = logging.getLogger(__name__) @@ -43,12 +44,40 @@ def serve_command(args): "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)" ) sys.exit(1) + if args.max_tokens < 1: + print("Error: --max-tokens must be at least 1") + sys.exit(1) + max_request_tokens = getattr(args, "max_request_tokens", args.max_tokens) + trust_remote_code = getattr(args, "trust_remote_code", False) + if max_request_tokens < 1: + print("Error: --max-request-tokens must be at least 1") + sys.exit(1) + if args.max_tokens > max_request_tokens: + print("Error: --max-tokens cannot exceed --max-request-tokens") + sys.exit(1) + + # Validate --turbo-kv-bits capability before spinning up model/server + if args.turbo_kv_bits is not None: + if args.kv_cache_quantization: + print( + "Error: --turbo-kv-bits and --kv-cache-quantization are mutually " + "exclusive; pick one compression path" + ) + sys.exit(1) + + from .memory_cache import _check_turboquant_capability + + missing = _check_turboquant_capability() + if missing is not None: + print(f"Error: --turbo-kv-bits requires TurboQuant support: {missing}") + sys.exit(1) # Configure server security settings server._api_key = args.api_key server._default_timeout = args.timeout server._metrics_enabled = args.enable_metrics server._metrics.configure(enabled=args.enable_metrics) + server._max_request_tokens = max_request_tokens if args.rate_limit > 0: server._rate_limiter = RateLimiter( requests_per_minute=args.rate_limit, enabled=True @@ -67,6 +96,11 @@ def serve_command(args): server._default_temperature = args.default_temperature if args.default_top_p is not None: server._default_top_p = args.default_top_p + server._default_chat_template_kwargs = args.default_chat_template_kwargs + max_audio_upload_mb = getattr(args, "max_audio_upload_mb", 25) + max_tts_input_chars = getattr(args, "max_tts_input_chars", 4096) + server._max_audio_upload_bytes = max_audio_upload_mb * 1024 * 1024 + server._max_tts_input_chars = max_tts_input_chars # Configure reasoning parser if args.reasoning_parser: @@ -108,6 +142,14 @@ def serve_command(args): print(" Metrics: ENABLED (/metrics, unauthenticated)") else: print(" Metrics: DISABLED - Use --enable-metrics to expose /metrics") + if trust_remote_code: + print(" Remote code loading: ENABLED (--trust-remote-code)") + else: + print(" Remote code loading: DISABLED (default)") + if args.auto_unload_idle_seconds > 0: + print(f" Idle auto-unload: ENABLED ({args.auto_unload_idle_seconds:.0f}s)") + else: + print(" Idle auto-unload: DISABLED") if args.enable_auto_tool_choice: print(f" Tool calling: ENABLED (parser: {args.tool_call_parser})") else: @@ -116,6 +158,10 @@ def serve_command(args): print(f" Reasoning: ENABLED (parser: {args.reasoning_parser})") else: print(" Reasoning: Use --reasoning-parser to enable") + print( + f" Audio upload limit: {max_audio_upload_mb} MiB, " + f"TTS input limit: {max_tts_input_chars} chars" + ) print("=" * 60) # Pre-download model with retry/timeout @@ -133,8 +179,13 @@ def serve_command(args): is_mllm=is_mllm_model(args.model), ) - print(f"Loading model: {args.model}") + if args.lazy_load_model: + print(f"Registering model for lazy load: {args.model}") + print("Model will load on the first request.") + else: + print(f"Loading model: {args.model}") print(f"Default max tokens: {args.max_tokens}") + print(f"Max request tokens: {max_request_tokens}") # Store MCP config path for FastAPI startup 
if args.mcp_config: @@ -142,14 +193,24 @@ def serve_command(args): os.environ["VLLM_MLX_MCP_CONFIG"] = args.mcp_config # Pre-load embedding model if specified - if args.embedding_model: - print(f"Pre-loading embedding model: {args.embedding_model}") - server.load_embedding_model(args.embedding_model, lock=True) - print(f"Embedding model loaded: {args.embedding_model}") + embedding_model = getattr(args, "embedding_model", None) + if embedding_model: + print(f"Pre-loading embedding model: {embedding_model}") + server.load_embedding_model(embedding_model, lock=True) + print(f"Embedding model loaded: {embedding_model}") + + # Pre-load reranker model if specified + rerank_model = getattr(args, "rerank_model", None) + if rerank_model: + print(f"Pre-loading reranker model: {rerank_model}") + server.load_reranker_model(rerank_model, lock=True) + print(f"Reranker model loaded: {rerank_model}") # Build scheduler config for batched mode scheduler_config = None if args.continuous_batching: + from .scheduler import SchedulerConfig + # Handle prefix cache flags enable_prefix_cache = args.enable_prefix_cache and not args.disable_prefix_cache @@ -181,6 +242,11 @@ def serve_command(args): mllm_prefill_step_size=( args.mllm_prefill_step_size if args.mllm_prefill_step_size > 0 else None ), + # TurboQuant + turbo_kv_bits=args.turbo_kv_bits, + # SSD cache tiering + ssd_cache_dir=getattr(args, "ssd_cache_dir", None), + ssd_cache_max_gb=getattr(args, "ssd_cache_max_gb", 10.0), ) print("Mode: Continuous batching (for multiple concurrent users)") @@ -200,7 +266,9 @@ def serve_command(args): else f"{args.cache_memory_percent*100:.0f}% of RAM" ) print(f"Memory-aware cache: {cache_info}") - if args.kv_cache_quantization: + if args.turbo_kv_bits: + print(f"TurboQuant: {args.turbo_kv_bits}-bit prefix cache compression") + elif args.kv_cache_quantization: print( f"KV cache quantization: {args.kv_cache_quantization_bits}-bit, " f"group_size={args.kv_cache_quantization_group_size}" @@ -227,15 +295,20 @@ def serve_command(args): scheduler_config=scheduler_config, stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, + max_request_tokens=max_request_tokens, force_mllm=getattr(args, "mllm", False), gpu_memory_utilization=args.gpu_memory_utilization, served_model_name=args.served_model_name, + trust_remote_code=trust_remote_code, mtp=args.enable_mtp, prefill_step_size=args.prefill_step_size, specprefill_enabled=args.specprefill, specprefill_threshold=args.specprefill_threshold, specprefill_keep_pct=args.specprefill_keep_pct, specprefill_draft_model=args.specprefill_draft_model, + warm_prompts_path=getattr(args, "warm_prompts", None), + auto_unload_idle_seconds=args.auto_unload_idle_seconds, + lazy_load_model=args.lazy_load_model, ) # Start server @@ -274,6 +347,22 @@ def bench_command(args): # Handle prefix cache flags enable_prefix_cache = args.enable_prefix_cache and not args.disable_prefix_cache + # Validate --turbo-kv-bits capability before loading the model + if args.turbo_kv_bits is not None: + if args.kv_cache_quantization: + print( + "Error: --turbo-kv-bits and --kv-cache-quantization are mutually " + "exclusive; pick one compression path" + ) + sys.exit(1) + + from .memory_cache import _check_turboquant_capability + + missing = _check_turboquant_capability() + if missing is not None: + print(f"Error: --turbo-kv-bits requires TurboQuant support: {missing}") + sys.exit(1) + async def run_benchmark(): print(f"Loading model: {args.model}") model, tokenizer = load(args.model) 
@@ -297,6 +386,8 @@ async def run_benchmark(): kv_cache_quantization_bits=args.kv_cache_quantization_bits, kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, + # TurboQuant + turbo_kv_bits=args.turbo_kv_bits, ) engine_config = EngineConfig( @@ -643,6 +734,72 @@ def bench_kv_cache_command(args): ) +def bench_serve_command(args): + """Run serving benchmark.""" + import asyncio + from .bench_serve import run_bench_serve + + prompt_sets = args.prompts.split(",") + concurrencies = [int(c) for c in args.concurrency.split(",")] + + # Parse thinking values + thinking_values = [None] + if args.enable_thinking: + thinking_values = [] + for v in args.enable_thinking.split(","): + v = v.strip().lower() + if v == "true": + thinking_values.append(True) + elif v == "false": + thinking_values.append(False) + + # Parse extra body (comma-separated JSON dicts) + extra_bodies = [""] + if args.extra_body: + # Handle both '{"a":1}','{"b":2}' and {"a":1},{"b":2} + import re + + extra_bodies = [ + s.strip().strip("'\"") + for s in re.split(r"(?<=})\s*,\s*(?={)", args.extra_body) + ] + + # Parse override fields + overrides = {} + for kv in args.override_field or []: + if "=" in kv: + k, v = kv.split("=", 1) + overrides[k] = v + + asyncio.run( + run_bench_serve( + url=args.url, + model=args.model, + prompt_sets=prompt_sets, + prompt_file=args.prompt_file, + concurrencies=concurrencies, + max_tokens=args.max_tokens, + repetitions=args.repetitions, + warmup=args.warmup, + thinking_values=thinking_values, + extra_bodies=extra_bodies, + output_path=args.output, + fmt=args.format, + do_validate=args.validate == "true", + scrape=args.scrape_metrics == "true", + tag=args.tag, + override_fields=overrides, + system_prompt_file=args.system_prompt_file, + # Auto-enable skip-preflight when a system-prompt-file is set: + # the whole point of that flag is measuring warm-cache behavior, + # and the preflight count_prompt_tokens request pollutes the cache. + skip_preflight_token_count=( + args.skip_preflight_token_count or bool(args.system_prompt_file) + ), + ) + ) + + def create_parser() -> argparse.ArgumentParser: """Build the top-level CLI parser.""" parser = argparse.ArgumentParser( @@ -666,7 +823,10 @@ def create_parser() -> argparse.ArgumentParser: help="The model name used in the API. If not specified, the model argument is used.", ) serve_parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Host to bind" + "--host", + type=str, + default="127.0.0.1", + help="Host to bind (default: localhost; use 0.0.0.0 to expose externally)", ) serve_parser.add_argument("--port", type=int, default=8000, help="Port to bind") serve_parser.add_argument( @@ -744,6 +904,43 @@ def create_parser() -> argparse.ArgumentParser: default=256, help="Minimum tokens for quantization to apply (default: 256)", ) + # TurboQuant KV cache compression (arXiv 2504.19874) + serve_parser.add_argument( + "--turbo-kv-bits", + type=int, + default=None, + choices=[1, 2, 3, 4], + help="TurboQuant KV cache compression bits for prefix cache. " + "3-bit gives 4.6x compression vs FP16 (default: disabled). 
" + "Mutually exclusive with --kv-cache-quantization.", + ) + # SSD cache tiering options + serve_parser.add_argument( + "--ssd-cache-dir", + type=str, + default=None, + help="Directory for SSD KV cache tier (default: disabled)", + ) + serve_parser.add_argument( + "--ssd-cache-max-gb", + type=float, + default=10.0, + help="Maximum SSD cache size in GB (default: 10.0)", + ) + # Prompt warm-up options + serve_parser.add_argument( + "--warm-prompts", + type=str, + default=None, + help=( + "Path to a JSON file with prompts to pre-run at startup. Populates " + "the prefix cache so the first real request hits warm (cold TTFT " + "drops 1.3-2.3x on agent workloads). File format is a list of " + "message arrays, same shape as /v1/chat/completions messages. " + "Prompts are warmed concurrently -- keep the file small (1-3 entries " + "for typical agent deployments) to avoid memory pressure at boot." + ), + ) serve_parser.add_argument( "--stream-interval", type=int, @@ -756,6 +953,12 @@ def create_parser() -> argparse.ArgumentParser: default=32768, help="Default max tokens for generation (default: 32768)", ) + serve_parser.add_argument( + "--max-request-tokens", + type=int, + default=32768, + help="Maximum max_tokens accepted from API clients (default: 32768)", + ) serve_parser.add_argument( "--continuous-batching", action="store_true", @@ -885,6 +1088,29 @@ def create_parser() -> argparse.ArgumentParser: action="store_true", help="Expose Prometheus metrics on /metrics (disabled by default)", ) + serve_parser.add_argument( + "--auto-unload-idle-seconds", + type=float, + default=0.0, + help="Unload the main model after this many idle seconds (0 = disabled)", + ) + serve_parser.add_argument( + "--lazy-load-model", + action="store_true", + help="Register the main model at startup but defer loading until first request", + ) + serve_parser.add_argument( + "--max-audio-upload-mb", + type=int, + default=25, + help="Maximum size of uploaded audio files in MiB (default: 25)", + ) + serve_parser.add_argument( + "--max-tts-input-chars", + type=int, + default=4096, + help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)", + ) # Tool calling options serve_parser.add_argument( "--enable-auto-tool-choice", @@ -943,6 +1169,11 @@ def create_parser() -> argparse.ArgumentParser: action="store_true", help="Force load model as multimodal (vision) even if name doesn't match auto-detection patterns", ) + serve_parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow HuggingFace remote code execution during model/tokenizer loading", + ) # Generation defaults serve_parser.add_argument( "--default-temperature", @@ -956,6 +1187,16 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Override default top_p for all requests (default: use model default)", ) + serve_parser.add_argument( + "--default-chat-template-kwargs", + type=make_json_object_arg_parser("--default-chat-template-kwargs"), + default=None, + help=( + "Default chat template kwargs to apply to all requests when request " + "chat_template_kwargs is omitted or empty; empty request kwargs use " + 'existing server defaults (JSON object, e.g. {"enable_thinking": true})' + ), + ) # Embedding model option serve_parser.add_argument( "--embedding-model", @@ -963,6 +1204,13 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Pre-load an embedding model at startup (e.g. 
mlx-community/embeddinggemma-300m-6bit)", ) + # Reranker model option + serve_parser.add_argument( + "--rerank-model", + type=str, + default=None, + help="Pre-load a reranker model at startup (e.g. mlx-community/jina-reranker-v2-base-multilingual)", + ) # Download options serve_parser.add_argument( "--download-timeout", @@ -1059,6 +1307,16 @@ def create_parser() -> argparse.ArgumentParser: default=256, help="Minimum tokens for quantization to apply (default: 256)", ) + # TurboQuant KV cache compression + bench_parser.add_argument( + "--turbo-kv-bits", + type=int, + default=None, + choices=[1, 2, 3, 4], + help="TurboQuant KV cache compression bits for prefix cache. " + "3-bit gives 4.6x compression vs FP16 (default: disabled). " + "Mutually exclusive with --kv-cache-quantization.", + ) # Paged cache options (experimental) bench_parser.add_argument( "--use-paged-cache", @@ -1139,9 +1397,140 @@ def create_parser() -> argparse.ArgumentParser: help="Download as multimodal model (broader file patterns)", ) + # Serving benchmark + bench_serve_parser = subparsers.add_parser( + "bench-serve", help="Benchmark a running vllm-mlx server via HTTP API" + ) + bench_serve_parser.add_argument( + "--url", + type=str, + default="http://127.0.0.1:8080", + help="Base URL of the running server (default: http://127.0.0.1:8080)", + ) + bench_serve_parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID to benchmark (default: auto-detected from server)", + ) + bench_serve_parser.add_argument( + "--prompts", + type=str, + default="short,medium,long", + help="Comma-separated prompt set names or paths (default: short,medium,long)", + ) + bench_serve_parser.add_argument( + "--prompt-file", + type=str, + default=None, + help="Path to an additional prompt file (JSON list of message dicts)", + ) + bench_serve_parser.add_argument( + "--system-prompt-file", + type=str, + default=None, + help=( + "Path to a text file whose contents are prepended as a system " + "message to every prompt. Use this together with --warm-prompts " + "to benchmark the warm-cache path (the warmup populates the " + "prefix cache with this same system, so every request in the " + "bench hits the cache)." + ), + ) + bench_serve_parser.add_argument( + "--skip-preflight-token-count", + action="store_true", + help=( + "Skip the pre-flight max_tokens=1 request that counts prompt " + "tokens per prompt set. That request populates the prefix cache " + "with the full prompt, which defeats cold-vs-warm comparisons. " + "Auto-enabled when --system-prompt-file is set; pass this flag " + "explicitly to force-enable regardless." 
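+            # (Illustrative warm-cache sweep; the file name is a placeholder:
+            #   vllm-mlx bench-serve --url http://127.0.0.1:8000 \
+            #     --system-prompt-file agent_system.txt --prompts short \
+            #     --concurrency 1,4
+            #  skip-preflight is then enabled automatically by the flag above.)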
+ ), + ) + bench_serve_parser.add_argument( + "--concurrency", + type=str, + default="1,4", + help="Comma-separated concurrency levels to sweep (default: 1,4)", + ) + bench_serve_parser.add_argument( + "--max-tokens", + type=int, + default=256, + help="Maximum tokens to generate per request (default: 256)", + ) + bench_serve_parser.add_argument( + "--repetitions", + type=int, + default=3, + help="Number of repetitions per sweep configuration (default: 3)", + ) + bench_serve_parser.add_argument( + "--warmup", + type=int, + default=1, + help="Warmup rounds before the first measured repetition (default: 1)", + ) + bench_serve_parser.add_argument( + "--enable-thinking", + type=str, + default=None, + help='Enable thinking mode: "true", "false", or "true,false" to sweep both', + ) + bench_serve_parser.add_argument( + "--extra-body", + type=str, + default=None, + help="Comma-separated JSON dicts to pass as extra body parameters", + ) + bench_serve_parser.add_argument( + "--output", + type=str, + default=None, + help="File path to write results to (default: stdout)", + ) + bench_serve_parser.add_argument( + "--format", + type=str, + default="table", + choices=["table", "json", "csv", "sql"], + help="Output format (default: table)", + ) + bench_serve_parser.add_argument( + "--validate", + type=str, + default="true", + choices=["true", "false"], + help="Validate responses (default: true)", + ) + bench_serve_parser.add_argument( + "--scrape-metrics", + type=str, + default="true", + choices=["true", "false"], + help="Scrape /metrics before and after each run (default: true)", + ) + bench_serve_parser.add_argument( + "--tag", + type=str, + default=None, + help="Optional tag string stored in every result row", + ) + bench_serve_parser.add_argument( + "--override-field", + nargs="*", + default=[], + help="Override result fields as key=value pairs (e.g. 
chip=M4Pro)", + ) + return parser +# Alias for test compatibility +build_parser = create_parser + + def main(): parser = create_parser() args = parser.parse_args() @@ -1156,6 +1545,8 @@ def main(): bench_kv_cache_command(args) elif args.command == "download": download_command(args) + elif args.command == "bench-serve": + bench_serve_command(args) else: parser.print_help() sys.exit(1) diff --git a/vllm_mlx/cli_arg_types.py b/vllm_mlx/cli_arg_types.py new file mode 100644 index 000000000..411e608f3 --- /dev/null +++ b/vllm_mlx/cli_arg_types.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Argparse type helpers shared by CLI entrypoints.""" + +import argparse +import json +from collections.abc import Callable +from typing import Any + + +def parse_json_object_arg(value: str, option_name: str) -> dict[str, Any]: + """Parse and validate that an option value is a JSON object.""" + try: + parsed = json.loads(value) + except json.JSONDecodeError as exc: + raise argparse.ArgumentTypeError( + f"{option_name} must be a valid JSON object: {exc.msg}" + ) from exc + + if not isinstance(parsed, dict): + raise argparse.ArgumentTypeError(f"{option_name} must be a JSON object") + + return parsed + + +def make_json_object_arg_parser(option_name: str) -> Callable[[str], dict[str, Any]]: + """Create an argparse type parser for JSON object options.""" + + def _parser(value: str) -> dict[str, Any]: + return parse_json_object_arg(value, option_name) + + return _parser diff --git a/vllm_mlx/constrained/__init__.py b/vllm_mlx/constrained/__init__.py new file mode 100644 index 000000000..758ab68ac --- /dev/null +++ b/vllm_mlx/constrained/__init__.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Constrained decoding for grammar-guided generation. + +Provides logits processors that mask token probabilities during generation +so the model can only emit sequences matching a target grammar (e.g. a JSON +schema). Used by the ``response_format`` parameter on the chat completion +and Anthropic Messages endpoints. +""" + +from .json_schema_processor import ( + JSONSchemaLogitsProcessor, + LMFormatEnforcerNotAvailableError, + is_available, +) + +__all__ = [ + "JSONSchemaLogitsProcessor", + "LMFormatEnforcerNotAvailableError", + "is_available", +] diff --git a/vllm_mlx/constrained/cache.py b/vllm_mlx/constrained/cache.py new file mode 100644 index 000000000..714caf611 --- /dev/null +++ b/vllm_mlx/constrained/cache.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Cache of ``TokenEnforcerTokenizerData`` objects keyed by tokenizer identity. + +Building ``TokenEnforcerTokenizerData`` requires iterating over the entire +vocabulary (up to 200k tokens on MiniMax/GLM) and decoding each token. The +cost is ~1-2 seconds per model and the result is independent of the JSON +schema, so we cache it for the lifetime of the process. +""" + +from __future__ import annotations + +import functools +import logging +import threading +from typing import Any + +logger = logging.getLogger(__name__) + +# Keyed by ``id(tokenizer)`` because tokenizer instances are not hashable +# across HF/MLX wrappers, but each server loads a single tokenizer per model. +_CACHE: dict[int, Any] = {} +_CACHE_LOCK = threading.Lock() + + +def _resolve_inner_tokenizer(tokenizer: Any) -> Any: + """ + VLM processors wrap the actual tokenizer under ``processor.tokenizer``. + ``mlx_lm.tokenizer_utils.TokenizerWrapper`` exposes it via ``_tokenizer``. 
+ Return the most-unwrapped tokenizer that still has the HF + ``all_special_ids`` / ``eos_token_id`` surface. + + Note: on HF ``PreTrainedTokenizerFast``, ``_tokenizer`` points at the + rust-level object which lacks ``all_special_ids``; unwrapping to that + level would cause every special token (````, ````, ``\\n``, + ``<|think|>`` …) to leak into ``regular_tokens`` and end up in + ``TokenizerPrefixTree.root`` as an always-allowed token. We only + unwrap when the inner layer still exposes ``all_special_ids``. + """ + # VLM processor wrapper exposes the HF tokenizer under ``tokenizer``. + inner = getattr(tokenizer, "tokenizer", None) + if ( + inner is not None + and inner is not tokenizer + and hasattr(inner, "all_special_ids") + ): + tokenizer = inner + # mlx_lm TokenizerWrapper keeps the raw HF tokenizer under ``_tokenizer``. + # Only unwrap if the inner object still exposes the HF tokenizer surface. + inner = getattr(tokenizer, "_tokenizer", None) + if inner is not None and hasattr(inner, "all_special_ids"): + tokenizer = inner + return tokenizer + + +def _build_regular_tokens_list( + tokenizer: Any, vocab_size: int +) -> list[tuple[int, str, bool]]: + """ + Enumerate the regular (non-special) tokens in the vocabulary and produce + the ``(token_id, decoded_with_leading_space_marker, is_word_start)`` tuples + required by ``TokenEnforcerTokenizerData``. + + Mirrors the reference implementation in ``lmformatenforcer.integrations. + transformers`` but works with the HF tokenizer surface only (so we do not + need a hard transformers dependency at the right version). + """ + try: + special_ids = set(tokenizer.all_special_ids) + except AttributeError: + special_ids = set() + + try: + token_0 = tokenizer.encode("0")[-1] + except Exception: + token_0 = None + + regular_tokens: list[tuple[int, str, bool]] = [] + for token_idx in range(vocab_size): + if token_idx in special_ids: + continue + try: + decoded_regular = tokenizer.decode([token_idx]) + except Exception: + continue + if token_0 is not None: + try: + decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:] + except Exception: + decoded_after_0 = decoded_regular + else: + decoded_after_0 = decoded_regular + is_word_start_token = len(decoded_after_0) > len(decoded_regular) + regular_tokens.append((token_idx, decoded_after_0, is_word_start_token)) + return regular_tokens + + +def _get_eos_token_id(tokenizer: Any) -> int | list[int]: + # Some tokenizers expose multiple EOS candidates (e.g. Gemma 4 has + # [1 , 106 , 50 <|think|>] in generation_config.json). + # Prefer the list form so all stop tokens are treated as EOS by the + # enforcer; otherwise the model may emit an out-of-schema stop token + # that the enforcer did not mask (because it's a special token not in + # ``regular_tokens``), yet the inference runtime still treats as stop. 
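+    # (Illustrative: a tokenizer exposing eos_token_ids=[1, 106, 50] is
+    # returned as that list; a tokenizer with only a scalar eos_token_id=2
+    # falls through to the single-id fallback below.)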
+ eos_list = getattr(tokenizer, "eos_token_ids", None) + if isinstance(eos_list, (list, tuple)) and eos_list: + return list(eos_list) + eos = getattr(tokenizer, "eos_token_id", None) + if eos is not None: + return eos + return 0 + + +def _get_vocab_size(tokenizer: Any) -> int: + vs = getattr(tokenizer, "vocab_size", None) + if isinstance(vs, int) and vs > 0: + return vs + try: + return len(tokenizer) + except TypeError: + pass + get_vocab = getattr(tokenizer, "get_vocab", None) + if callable(get_vocab): + return len(get_vocab()) + raise ValueError("Cannot determine tokenizer vocab size") + + +def _decode_function(tokenizer: Any, tokens: list[int]) -> str: + try: + decoded = tokenizer.decode(tokens) + except Exception: + return "" + return decoded.rstrip("\ufffd") if isinstance(decoded, str) else "" + + +def get_tokenizer_data(tokenizer: Any) -> Any | None: + """ + Return a cached ``TokenEnforcerTokenizerData`` for ``tokenizer``. + + Returns ``None`` if ``lm-format-enforcer`` is not installed or the + tokenizer cannot be adapted. + """ + try: + from lmformatenforcer.tokenenforcer import TokenEnforcerTokenizerData + except ImportError: + return None + + inner = _resolve_inner_tokenizer(tokenizer) + key = id(inner) + with _CACHE_LOCK: + cached = _CACHE.get(key) + if cached is not None: + return cached + + try: + vocab_size = _get_vocab_size(inner) + except Exception as exc: + logger.warning( + "Could not determine vocab size for constrained decoding: %s", exc + ) + return None + + try: + regular_tokens = _build_regular_tokens_list(inner, vocab_size) + decode_fn = functools.partial(_decode_function, inner) + eos_token_id = _get_eos_token_id(inner) + data = TokenEnforcerTokenizerData( + regular_tokens, + decode_fn, + eos_token_id, + False, # use_bitmask + vocab_size, + ) + except Exception as exc: + logger.warning("Failed to build TokenEnforcerTokenizerData: %s", exc) + return None + + with _CACHE_LOCK: + _CACHE[key] = data + return data + + +def clear_cache() -> None: + """Drop the cache (mainly for tests).""" + with _CACHE_LOCK: + _CACHE.clear() diff --git a/vllm_mlx/constrained/json_schema_processor.py b/vllm_mlx/constrained/json_schema_processor.py new file mode 100644 index 000000000..9e5393319 --- /dev/null +++ b/vllm_mlx/constrained/json_schema_processor.py @@ -0,0 +1,755 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +``JSONSchemaLogitsProcessor`` — a ``mlx_lm``-compatible logits processor that +masks the vocabulary so the model can only emit tokens forming a valid JSON +value (optionally matching a JSON schema). + +The processor implements the signature expected by ``mlx_lm.generate.generate_step`` +and ``vllm_mlx``'s batched engine alike: + + processor(tokens: mx.array, logits: mx.array) -> mx.array + +``tokens`` contains the full sequence generated for this request so far +(prompt + previously emitted tokens), and ``logits`` is the last-step logits +row. 
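+
+A minimal usage sketch (illustrative only; ``schema``, ``tokenizer``,
+``tokens_so_far`` and ``last_step_logits`` are placeholders, not engine
+internals):
+
+    processor = JSONSchemaLogitsProcessor(schema, tokenizer)
+    # handed to the engine via the ``logits_processors`` sampling argument,
+    # then applied once per decode step:
+    masked_logits = processor(tokens_so_far, last_step_logits)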
+""" + +from __future__ import annotations + +import copy +import json +import logging +from typing import Any + +import mlx.core as mx +import numpy as np + +from .cache import get_tokenizer_data + +logger = logging.getLogger(__name__) + + +class LMFormatEnforcerNotAvailableError(RuntimeError): + """Raised when ``lm-format-enforcer`` is required but not installed.""" + + +def is_available() -> bool: + """Return ``True`` iff ``lm-format-enforcer`` is importable.""" + try: + import lmformatenforcer # noqa: F401 + except ImportError: + return False + return True + + +# A permissive "any JSON value" schema for ``response_format.type = json_object``. +# JSON spec allows any of {object, array, string, number, boolean, null} at the +# top level, but realistic OpenAI ``json_object`` mode expects an object or +# array at the root. +_GENERIC_JSON_SCHEMA: dict = { + "anyOf": [ + {"type": "object"}, + {"type": "array"}, + ] +} + + +def _simplify_schema(schema: dict) -> dict: + """Pre-process a JSON Schema for ``lm-format-enforcer`` compatibility. + + ``lm-format-enforcer`` does not support ``$ref``, ``not``, ``type`` as an + array, or recursive definitions. This function: + + 1. Resolves ``$ref`` by inlining referenced definitions (with cycle + detection so recursive definitions are truncated to ``{}``). + 2. Removes ``not`` sub-schemas (makes the schema more permissive). + 3. Strips metadata / serialisation-hint keywords that the enforcer does + not understand: ``default``, ``examples``, ``title``, ``description``, + ``$schema``, ``$id``. + 4. Converts ``type: [t1, t2, ...]`` to ``anyOf: [{type: t1}, ...]``. + 5. Cleans up empty ``anyOf`` / ``oneOf`` branches. + 6. Flattens nested ``anyOf``/``oneOf`` (e.g. + ``anyOf: [{anyOf: [A, B]}, C]`` → ``anyOf: [A, B, C]``). + """ + schema = copy.deepcopy(schema) + definitions: dict = {} + definitions.update(schema.pop("definitions", {})) + definitions.update(schema.pop("$defs", {})) + + resolving: set[str] = set() # cycle guard + + def _resolve(node: Any, depth: int = 0) -> Any: + if depth > 12 or not isinstance(node, dict): + return node + + # --- resolve $ref -------------------------------------------------- + if "$ref" in node: + ref: str = node["$ref"] + parts = ref.split("/") + if ( + len(parts) == 3 + and parts[0] == "#" + and parts[1] in ("definitions", "$defs") + ): + name = parts[2] + if name in definitions and ref not in resolving: + resolving.add(ref) + resolved = copy.deepcopy(definitions[name]) + # Merge extra keys (e.g. ``default``) from the $ref node. + for k, v in node.items(): + if k != "$ref" and k not in resolved: + resolved[k] = v + result = _resolve(resolved, depth + 1) + resolving.discard(ref) + return result + # Circular or unresolvable — return empty (= any). + return {} + + # --- remove unsupported keywords ----------------------------------- + node.pop("not", None) + node.pop("$schema", None) + node.pop("$id", None) + # Metadata / serialisation hints that lm-format-enforcer doesn't + # understand — keeping them causes the parser to mis-navigate. 
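+            # (Illustrative, hypothetical input: {"type": ["string", "null"],
+            # "default": 0} loses "default" in the pops below, and the type
+            # array is rewritten further down as
+            # {"anyOf": [{"type": "string"}, {"type": "null"}]}.)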
+ node.pop("default", None) + node.pop("examples", None) + node.pop("title", None) + node.pop("description", None) + + # --- type array → anyOf -------------------------------------------- + if isinstance(node.get("type"), list): + types = node.pop("type") + items_schema = node.pop("items", None) + branches: list[dict] = [] + for t in types: + branch: dict[str, Any] = {"type": t} + if t == "array" and items_schema is not None: + branch["items"] = _resolve(copy.deepcopy(items_schema), depth + 1) + branches.append(branch) + existing = node.pop("anyOf", []) + node["anyOf"] = existing + branches + + # --- recurse into sub-schemas -------------------------------------- + if "properties" in node and isinstance(node["properties"], dict): + for k in list(node["properties"]): + node["properties"][k] = _resolve(node["properties"][k], depth + 1) + + for key in ("items", "additionalProperties"): + if key in node and isinstance(node[key], dict): + node[key] = _resolve(node[key], depth + 1) + + for key in ("allOf", "anyOf", "oneOf"): + if key in node and isinstance(node[key], list): + # Resolve each branch; drop empty dicts (= "any", redundant + # inside anyOf since they make the whole constraint trivially + # true — but keeping one "any" branch confuses the enforcer). + resolved_items = [_resolve(item, depth + 1) for item in node[key]] + node[key] = [it for it in resolved_items if it != {}] + if not node[key]: + del node[key] + + # --- flatten nested anyOf/oneOf ------------------------------------ + # ``anyOf: [{anyOf: [A, B]}, C]`` → ``anyOf: [A, B, C]`` when the + # wrapper dict has no extra keys. This removes one level of nesting + # that confuses lm-format-enforcer's UnionParser. + for key in ("anyOf", "oneOf"): + if key in node and isinstance(node[key], list): + flattened: list[Any] = [] + for item in node[key]: + if isinstance(item, dict) and key in item and len(item) == 1: + flattened.extend(item[key]) + else: + flattened.append(item) + node[key] = flattened + + return node + + return _resolve(schema) + + +def _force_no_additional_properties(schema: dict) -> dict: + """Return a deep copy of *schema* with ``additionalProperties: false`` + injected into every object-type sub-schema that declares ``properties``. + + ``lm-format-enforcer`` has a bug where multi-character tokens spanning + JSON structural boundaries (e.g., a single token that decodes to ``""``) + can produce empty or whitespace-only keys, causing ``KeyError`` crashes in + ``jsonschemaparser.py``. Setting ``additionalProperties: false`` tells the + enforcer's trie traversal that only the declared property names are valid + keys, which significantly narrows the allowed tokens and prevents most of + these boundary-spanning issues. 
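+
+    Illustrative effect on a hypothetical schema (not taken from the code
+    base): an input like ``{"type": "object", "properties": {"name":
+    {"type": "string"}}}`` comes back with ``"additionalProperties": false``
+    injected next to ``properties``, and nested object sub-schemas receive
+    the same injection.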
+ """ + schema = copy.deepcopy(schema) + _inject_no_additional_props(schema) + return schema + + +def _inject_no_additional_props(node: Any) -> None: + """Recursively inject ``additionalProperties: false`` into *node*.""" + if not isinstance(node, dict): + return + if "properties" in node and "additionalProperties" not in node: + node["additionalProperties"] = False + for value in node.values(): + if isinstance(value, dict): + _inject_no_additional_props(value) + elif isinstance(value, list): + for item in value: + _inject_no_additional_props(item) + + +def _collect_property_names(schema: dict | None) -> set[str]: + """Collect all property names declared anywhere in *schema*.""" + names: set[str] = set() + if schema is None: + return names + _walk_properties(schema, names) + return names + + +def _walk_properties(node: Any, names: set[str]) -> None: + if not isinstance(node, dict): + return + props = node.get("properties") + if isinstance(props, dict): + names.update(props.keys()) + for v in props.values(): + _walk_properties(v, names) + for key in ("items", "additionalProperties", "not"): + if key in node and isinstance(node[key], dict): + _walk_properties(node[key], names) + for key in ("allOf", "anyOf", "oneOf"): + if key in node and isinstance(node[key], list): + for item in node[key]: + _walk_properties(item, names) + + +class JSONSchemaLogitsProcessor: + """ + Logits processor that constrains generation to valid JSON. + + Parameters + ---------- + schema: + The JSON Schema the output must match. When ``None``, any valid JSON + object/array is accepted (``json_object`` mode). + tokenizer: + The tokenizer used for generation. Its vocabulary is iterated once + (via :mod:`vllm_mlx.constrained.cache`) and cached for subsequent + requests. + """ + + def __init__( + self, + schema: dict | None, + tokenizer: Any, + ) -> None: + if not is_available(): + raise LMFormatEnforcerNotAvailableError( + "lm-format-enforcer is not installed. " + 'Install it with `pip install "lm-format-enforcer>=0.10.9"`.' + ) + + from lmformatenforcer import JsonSchemaParser, TokenEnforcer + + self._tokenizer = tokenizer + self._schema = schema + self._tok_data = get_tokenizer_data(tokenizer) + if self._tok_data is None: + raise LMFormatEnforcerNotAvailableError( + "Could not build TokenEnforcerTokenizerData for this tokenizer." + ) + + # Pre-process schema: resolve $ref, remove 'not', convert type arrays. + # Then harden: force ``additionalProperties: false`` on all object + # sub-schemas to prevent the enforcer from allowing arbitrary keys. + self._disabled = False + if schema is not None: + parser_schema = _simplify_schema(schema) + parser_schema = _force_no_additional_properties(parser_schema) + else: + parser_schema = _GENERIC_JSON_SCHEMA + + try: + self._parser = JsonSchemaParser(parser_schema) + self._enforcer = TokenEnforcer(self._tok_data, self._parser) + except Exception as exc: + logger.warning( + "JSONSchemaLogitsProcessor: enforcer init failed (%s); " + "falling back to unconstrained generation", + exc, + ) + self._disabled = True + self._parser = None # type: ignore[assignment] + self._enforcer = None # type: ignore[assignment] + + # Bootstrap the enforcer's ``prefix_states`` with the empty tuple so + # that subsequent ``get_allowed_tokens([t1, t2, ...])`` calls can find + # their ``prev_step_tuple`` and apply characters incrementally rather + # than treating the whole sequence as a prompt and resetting to the + # root parser. 
+ if not self._disabled: + try: + self._enforcer.get_allowed_tokens([]) + except Exception as exc: + logger.warning( + "TokenEnforcer bootstrap failed (%s); " + "falling back to unconstrained generation", + exc, + ) + self._disabled = True + + self._prompt_len: int | None = None + self._vocab_size: int = self._tok_data.vocab_size + + # EOS/stop tokens cache. + eos_id = getattr(self._tok_data, "eos_token_id", None) + if isinstance(eos_id, (list, tuple, set)): + self._eos_set: set[int] = {int(e) for e in eos_id} + elif eos_id is not None: + self._eos_set = {int(eos_id)} + else: + self._eos_set = set() + + # Pre-compute valid property name prefixes for key-start filtering. + all_names = _collect_property_names(schema) + self._valid_key_first_chars: set[str] = {n[0] for n in all_names if n} + self._valid_key_names: set[str] = all_names + + # Lazy decode cache — populated on demand. + self._token_decode_cache: dict[int, str | None] = {} + + # Suffix decode cache keyed by length. Full tokenizer.decode() + # is always used (incremental per-token decode is incorrect for + # BPE/SentencePiece tokenizers where whitespace is a token prefix). + self._cached_suffix_text: str = "" + self._cached_suffix_len: int = 0 + + # Incremental JSON context state — avoids re-scanning the full + # decoded text on every step. + self._json_ctx_in_string: bool = False + self._json_ctx_last_quote_pos: int = -1 + self._json_ctx_scanned_len: int = 0 + + # Bracket/brace depth counters for fast _suffix_is_complete_json + # pre-check. Updated by _get_json_context incrementally. JSON + # is only potentially complete when both are zero. + self._brace_depth: int = 0 + self._bracket_depth: int = 0 + + # ------------------------------------------------------------------ + + def _suffix(self, tokens_list: list[int]) -> list[int]: + """Return the slice of ``tokens`` that corresponds to generated output.""" + if self._prompt_len is None: + self._prompt_len = max(0, len(tokens_list) - 1) + return tokens_list[self._prompt_len :] + + def _decode_token_cached(self, tok_id: int) -> str | None: + """Return the decoded text for a single token (cached).""" + cached = self._token_decode_cache.get(tok_id) + if cached is not None: + return cached + if tok_id in self._token_decode_cache: + return None # previously cached as None + try: + decoded = self._tokenizer.decode([tok_id]) + except Exception: + self._token_decode_cache[tok_id] = None + return None + result = decoded if isinstance(decoded, str) else None + self._token_decode_cache[tok_id] = result + return result + + def _decode_suffix(self, suffix: list[int]) -> str | None: + """Decode suffix tokens to text. + + Always uses full ``tokenizer.decode(suffix)`` which is correct for + all tokenizer families (BPE, SentencePiece, etc.). Per-token + concatenation is NOT safe because whitespace may be encoded as a + token prefix (e.g. ``decode([1526]) = "world"`` but in context + ``decode([22557, 1526]) = "Hello world"``). + + Results are cached by suffix length to avoid redundant decodes + within the same generation step (``_get_json_context`` and + ``_suffix_is_complete_json`` both call this method). + """ + if not suffix: + self._cached_suffix_text = "" + self._cached_suffix_len = 0 + return "" + + suffix_len = len(suffix) + + # Fast path: already decoded this exact suffix length. + if suffix_len == self._cached_suffix_len: + return self._cached_suffix_text + + # Full decode (correct for all tokenizer families). 
+ try: + decoded = self._tokenizer.decode(list(suffix)) + except Exception: + return None + result = decoded if isinstance(decoded, str) else "" + + # Validate prefix stability for incremental JSON context scanning. + # decode(tokens[:n]) must be a prefix of decode(tokens[:n+1]) for + # the incremental scanner in _get_json_context to be correct. + if ( + suffix_len > self._cached_suffix_len + and self._cached_suffix_len > 0 + and not result.startswith(self._cached_suffix_text) + ): + # Prefix changed — reset incremental context state. + self._json_ctx_scanned_len = 0 + self._json_ctx_in_string = False + self._json_ctx_last_quote_pos = -1 + self._brace_depth = 0 + self._bracket_depth = 0 + + self._cached_suffix_text = result + self._cached_suffix_len = suffix_len + return result + + def _suffix_is_complete_json(self, suffix: list[int]) -> bool: + """Return True if the decoded ``suffix`` parses as a complete JSON value. + + Uses cached bracket/brace depth from ``_get_json_context`` as a + fast pre-check: JSON cannot be complete when brackets are + unbalanced or we are inside a string. This avoids the expensive + ``json.loads`` call on ~99% of steps. + """ + if not suffix: + return False + # Fast pre-check using cached structural state. + if self._brace_depth != 0 or self._bracket_depth != 0: + return False + if self._json_ctx_in_string: + return False + text = self._decode_suffix(suffix) + if not text: + return False + text = text.strip() + if not text: + return False + try: + json.loads(text) + except (ValueError, json.JSONDecodeError): + return False + return True + + def _get_json_context(self, suffix: list[int]) -> str: + """Determine the JSON structural context of the current suffix. + + Processes only newly appended characters instead of re-scanning + the full decoded text on every call (O(1) amortised per step + instead of O(n)). + + Returns one of: + - ``"key_start"``: expecting a new key (after ``{`` or ``,``) + - ``"in_key"``: inside an open key string + - ``"other"``: any other position + """ + text = self._decode_suffix(suffix) + if text is None or not text: + return "other" + + text_len = len(text) + + if text_len > self._json_ctx_scanned_len and self._json_ctx_scanned_len > 0: + # Incremental scan: process only new characters. + in_string = self._json_ctx_in_string + last_quote_pos = self._json_ctx_last_quote_pos + brace_depth = self._brace_depth + bracket_depth = self._bracket_depth + i = self._json_ctx_scanned_len + while i < text_len: + ch = text[i] + if in_string: + if ch == "\\" and i + 1 < text_len: + i += 2 + continue + if ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + last_quote_pos = i + elif ch == "{": + brace_depth += 1 + elif ch == "}": + brace_depth -= 1 + elif ch == "[": + bracket_depth += 1 + elif ch == "]": + bracket_depth -= 1 + i += 1 + self._json_ctx_in_string = in_string + self._json_ctx_last_quote_pos = last_quote_pos + self._json_ctx_scanned_len = text_len + self._brace_depth = brace_depth + self._bracket_depth = bracket_depth + else: + # Full scan (first call or text shrank/reset). 
+ in_string = False + last_quote_pos = -1 + brace_depth = 0 + bracket_depth = 0 + i = 0 + while i < text_len: + ch = text[i] + if in_string: + if ch == "\\" and i + 1 < text_len: + i += 2 + continue + if ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + last_quote_pos = i + elif ch == "{": + brace_depth += 1 + elif ch == "}": + brace_depth -= 1 + elif ch == "[": + bracket_depth += 1 + elif ch == "]": + bracket_depth -= 1 + i += 1 + self._json_ctx_in_string = in_string + self._json_ctx_last_quote_pos = last_quote_pos + self._json_ctx_scanned_len = text_len + self._brace_depth = brace_depth + self._bracket_depth = bracket_depth + + if self._json_ctx_in_string: + before = text[: self._json_ctx_last_quote_pos].rstrip() + if not before or before[-1] in ("{", ","): + return "in_key" + return "other" + + stripped = text.rstrip() + if not stripped: + return "other" + if stripped[-1] in ("{", ","): + return "key_start" + return "other" + + def _filter_at_key_context( + self, context: str, suffix: list[int], allowed: list[int] + ) -> list[int]: + """Apply schema-aware filtering when in key-related context. + + At ``key_start``: only allow tokens that begin a valid key, whitespace, + ``}``, or just ``"``. + At ``in_key``: only allow tokens compatible with continuing a valid + property name (no leading whitespace; content must be a valid prefix). + """ + if not self._valid_key_names: + return allowed # no schema info → skip filtering + + if context == "key_start": + return self._filter_key_start_tokens(suffix, allowed) + elif context == "in_key": + return self._filter_in_key_tokens(suffix, allowed) + return allowed + + def _filter_key_start_tokens( + self, suffix: list[int], allowed: list[int] + ) -> list[int]: + """Filter tokens at key-start position. + + Only permit tokens that: + - Are whitespace-only (before the key ``"``) + - Decode to ``}`` (close object) + - Start a valid key: ``"`` followed by a valid first char + """ + result = [] + for tok_id in allowed: + if tok_id in self._eos_set: + continue # EOS at key-start is handled separately + tok_text = self._decode_token_cached(tok_id) + if tok_text is None: + result.append(tok_id) + continue + stripped = tok_text.lstrip() + if not stripped: + # Pure whitespace — allowed before key + result.append(tok_id) + continue + if stripped[0] == "}": + # Closing brace — end of object + result.append(tok_id) + continue + if stripped[0] == '"': + # Opening a key — validate content + rest = stripped[1:] + if not rest: + # Just ``"`` — will be validated on next step + result.append(tok_id) + continue + # Check if rest starts with a valid key character + if rest[0] in self._valid_key_first_chars: + # Further check: does the key content (up to closing ``"``) + # match a prefix of a known property name? + close_idx = rest.find('"') + if close_idx < 0: + # Key not yet closed — check prefix + if self._is_valid_key_prefix(rest): + result.append(tok_id) + else: + # Key fully contained in this token + key_name = rest[:close_idx] + if key_name in self._valid_key_names: + result.append(tok_id) + continue + # First char not in valid set → skip + continue + # Other chars (digits, letters without quote) → skip at key-start + continue + return result if result else allowed # safety: never return empty + + def _filter_in_key_tokens(self, suffix: list[int], allowed: list[int]) -> list[int]: + """Filter tokens when we're inside an open key string. + + Only allow tokens whose content continues a valid property name. 
+ Reject whitespace-only/leading-whitespace tokens. + """ + # Figure out what key content we've accumulated so far. + text = self._decode_suffix(suffix) + if text is None: + return allowed + + # Find the last unmatched ``"`` — everything after it is key content + # accumulated so far. + last_open = text.rfind('"') + if last_open < 0: + return allowed + key_so_far = text[last_open + 1 :] + + result = [] + for tok_id in allowed: + tok_text = self._decode_token_cached(tok_id) + if tok_text is None: + result.append(tok_id) + continue + # Token must not start with whitespace (no ws inside keys) + if tok_text and tok_text[0] in (" ", "\t", "\n", "\r"): + continue + # Check if key_so_far + tok_text is a valid key prefix + candidate = key_so_far + tok_text + # If the closing ``"`` is in tok_text, extract the full key + close_idx = tok_text.find('"') + if close_idx >= 0: + full_key = key_so_far + tok_text[:close_idx] + if full_key in self._valid_key_names: + result.append(tok_id) + else: + # Key still open — check if it's a valid prefix + if self._is_valid_key_prefix(candidate): + result.append(tok_id) + return result if result else allowed + + def _is_valid_key_prefix(self, prefix: str) -> bool: + """Return True if *prefix* is a prefix of at least one valid key name.""" + return any(name.startswith(prefix) for name in self._valid_key_names) + + def _build_allow_mask(self, allowed: list[int], vocab_size: int) -> mx.array: + """ + Build a 1-D mask of length ``vocab_size`` where allowed positions are + ``0`` and disallowed positions are ``-inf``. + + Uses numpy for mask construction (C-level speed) instead of a + Python loop over ``vocab_size`` elements. + """ + if not allowed: + return mx.full((vocab_size,), -float("inf")) + allowed_clamped = [i for i in allowed if 0 <= i < vocab_size] + if not allowed_clamped: + return mx.full((vocab_size,), -float("inf")) + buf = np.full(vocab_size, -np.inf, dtype=np.float32) + buf[allowed_clamped] = 0.0 + return mx.array(buf) + + # ------------------------------------------------------------------ + + def __call__(self, tokens: mx.array, logits: mx.array) -> mx.array: + """Apply the allowed-tokens mask to ``logits``.""" + if self._disabled: + return logits + + try: + tokens_list = tokens.tolist() if hasattr(tokens, "tolist") else list(tokens) + if isinstance(tokens_list, int): + tokens_list = [tokens_list] + elif tokens_list and isinstance(tokens_list[0], list): + tokens_list = tokens_list[0] + + suffix = self._suffix(tokens_list) + # Use prompt_len directly instead of O(n) list comparison. + pass_to_enforcer = suffix if self._prompt_len else tokens_list + allowed_result = self._enforcer.get_allowed_tokens(pass_to_enforcer) + allowed = getattr(allowed_result, "allowed_tokens", allowed_result) + if allowed is None: + return logits + + allowed_list = list(allowed) + + # --- Schema-aware key filter (before EOS guard so that the + # incremental JSON context state and bracket depth counters + # are up-to-date for the _suffix_is_complete_json pre-check). 
+ context = self._get_json_context(suffix) + if context in ("key_start", "in_key"): + allowed_list = self._filter_at_key_context( + context, suffix, allowed_list + ) + + # --- EOS guard: only permit EOS when output is valid JSON --- + if ( + self._eos_set + and any(t in self._eos_set for t in allowed_list) + and not self._suffix_is_complete_json(suffix) + ): + allowed_list = [t for t in allowed_list if t not in self._eos_set] + + # --- Recovery: if enforcer returns empty set AND output is not + # complete JSON, the schema is likely unsupported — disable the + # processor and let the model generate freely (system prompt + + # post-validation still apply). Only force EOS if the output + # already parses as valid JSON (generation is done). + if not allowed_list: + if self._suffix_is_complete_json(suffix) and self._eos_set: + allowed_list = sorted(self._eos_set) + else: + logger.warning( + "JSONLP: enforcer stuck (empty allowed-set at " + "suffix_len=%d); disabling constrained decoding " + "for this request", + len(suffix), + ) + self._disabled = True + return logits + + actual_vocab = logits.shape[-1] + mask = self._build_allow_mask(allowed_list, actual_vocab) + if logits.ndim == 2 and logits.shape[0] == 1: + mask = mask[None, :] + return logits + mask + except Exception as exc: # pragma: no cover - defensive + logger.error( + "JSONSchemaLogitsProcessor crashed; disabling for this request: %s", + exc, + ) + self._disabled = True + return logits + + # Diagnostic helpers ------------------------------------------------- + + @property + def schema(self) -> dict | None: + return self._schema + + @property + def vocab_size(self) -> int: + return self._vocab_size diff --git a/vllm_mlx/endpoint_model_policies.py b/vllm_mlx/endpoint_model_policies.py new file mode 100644 index 000000000..6fdb9eada --- /dev/null +++ b/vllm_mlx/endpoint_model_policies.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Request-time model resolution policies for optional endpoints. + +These endpoints intentionally do not expose arbitrary Hugging Face loading from +user-controlled request bodies. Unknown model names must be rejected before any +engine instantiation or download path is reached. 
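+
+Illustrative call pattern (a sketch; ``requested`` stands in for the model
+name taken from the request body, and the real call sites are the API route
+handlers, which are not part of this module):
+
+    model_id = resolve_stt_model_name(requested)
+    # "whisper-small" -> "mlx-community/whisper-small-mlx"; unknown names
+    # raise HTTPException(status_code=400) before any model is loaded.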
+""" + +from fastapi import HTTPException + +_EMBEDDING_MODELS = frozenset( + { + "mlx-community/ModernBERT-base-mlx", + "mlx-community/all-MiniLM-L6-v2-4bit", + "mlx-community/bert-base-uncased-mlx", + "mlx-community/bge-large-en-v1.5-4bit", + "mlx-community/embeddinggemma-300m-6bit", + "mlx-community/multilingual-e5-large-mlx", + "mlx-community/multilingual-e5-small-mlx", + } +) + +_STT_MODEL_ALIASES = { + "whisper-large-v3": "mlx-community/whisper-large-v3-mlx", + "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo", + "whisper-medium": "mlx-community/whisper-medium-mlx", + "whisper-small": "mlx-community/whisper-small-mlx", + "parakeet": "mlx-community/parakeet-tdt-0.6b-v2", + "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3", +} + +_TTS_MODEL_ALIASES = { + "kokoro": "mlx-community/Kokoro-82M-bf16", + "kokoro-4bit": "mlx-community/Kokoro-82M-4bit", + "chatterbox": "mlx-community/chatterbox-turbo-fp16", + "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit", + "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit", + "voxcpm": "mlx-community/VoxCPM1.5", +} + + +def _with_identity_aliases(model_map: dict[str, str]) -> dict[str, str]: + expanded = dict(model_map) + for model_name in model_map.values(): + expanded[model_name] = model_name + return expanded + + +_STT_MODEL_MAP = _with_identity_aliases(_STT_MODEL_ALIASES) +_TTS_MODEL_MAP = _with_identity_aliases(_TTS_MODEL_ALIASES) + + +def _reject_unknown_embedding_model(requested_model: str) -> None: + supported = ", ".join(sorted(_EMBEDDING_MODELS)) + raise HTTPException( + status_code=400, + detail=( + f"Embedding model '{requested_model}' is not available. " + "Request-time embedding model loading is limited to the supported " + f"allowlist: {supported}. To use a different embedding model, start " + "the server with --embedding-model ." + ), + ) + + +def _reject_unknown_audio_model( + endpoint: str, + requested_model: str, + supported_aliases: dict[str, str], +) -> None: + aliases = ", ".join(sorted(supported_aliases)) + raise HTTPException( + status_code=400, + detail=( + f"{endpoint} model '{requested_model}' is not available. " + f"Supported request models are: {aliases}. Exact configured model IDs " + "for those aliases are also accepted." + ), + ) + + +def resolve_embedding_model_name( + requested_model: str, + *, + locked_model: str | None = None, +) -> str: + """Resolve the embedding model for a request or raise HTTP 400.""" + if locked_model is not None: + if requested_model == locked_model: + return locked_model + raise HTTPException( + status_code=400, + detail=( + f"Embedding model '{requested_model}' is not available. " + f"This server was started with --embedding-model {locked_model}. " + f"Only '{locked_model}' can be used for embeddings. Restart the " + f"server with a different --embedding-model to use '{requested_model}'." 
+ ), + ) + + if requested_model in _EMBEDDING_MODELS: + return requested_model + + _reject_unknown_embedding_model(requested_model) + + +def resolve_stt_model_name(requested_model: str) -> str: + """Resolve an STT request model alias or configured model ID.""" + if requested_model in _STT_MODEL_MAP: + return _STT_MODEL_MAP[requested_model] + _reject_unknown_audio_model("Transcription", requested_model, _STT_MODEL_ALIASES) + + +def resolve_tts_model_name(requested_model: str) -> str: + """Resolve a TTS request model alias or configured model ID.""" + if requested_model in _TTS_MODEL_MAP: + return _TTS_MODEL_MAP[requested_model] + _reject_unknown_audio_model("Speech", requested_model, _TTS_MODEL_ALIASES) diff --git a/vllm_mlx/engine/__init__.py b/vllm_mlx/engine/__init__.py index f6625abd0..c10934105 100644 --- a/vllm_mlx/engine/__init__.py +++ b/vllm_mlx/engine/__init__.py @@ -2,27 +2,53 @@ """ Engine abstraction for vllm-mlx inference. -Provides two engine implementations: -- SimpleEngine: Direct model calls for maximum single-user throughput -- BatchedEngine: Continuous batching for multiple concurrent users - -Also re-exports core engine components for backwards compatibility. +The package stays intentionally light at import time so server- and +contract-level tests can import API modules without eagerly importing MLX, +engine_core, or the batched engine stack. """ +from __future__ import annotations + +_ENGINE_CORE_NAMES = frozenset({"EngineCore", "AsyncEngineCore", "EngineConfig"}) +from typing import TYPE_CHECKING + from .base import BaseEngine, GenerationOutput -from .simple import SimpleEngine -from .batched import BatchedEngine -# Re-export from parent engine.py for backwards compatibility -from ..engine_core import EngineCore, AsyncEngineCore, EngineConfig +if TYPE_CHECKING: + from ..engine_core import AsyncEngineCore, EngineConfig, EngineCore + from .batched import BatchedEngine + from .simple import SimpleEngine __all__ = [ "BaseEngine", "GenerationOutput", "SimpleEngine", "BatchedEngine", - # Core engine components + # Core engine components (lazy) "EngineCore", "AsyncEngineCore", "EngineConfig", ] + + +def __getattr__(name: str): + if name == "SimpleEngine": + from .simple import SimpleEngine + + return SimpleEngine + + if name == "BatchedEngine": + from .batched import BatchedEngine + + return BatchedEngine + + if name in {"EngineCore", "AsyncEngineCore", "EngineConfig"}: + from ..engine_core import AsyncEngineCore, EngineConfig, EngineCore + + return { + "EngineCore": EngineCore, + "AsyncEngineCore": AsyncEngineCore, + "EngineConfig": EngineConfig, + }[name] + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_mlx/engine/base.py b/vllm_mlx/engine/base.py index fc5e70454..6970bb742 100644 --- a/vllm_mlx/engine/base.py +++ b/vllm_mlx/engine/base.py @@ -3,11 +3,16 @@ Base engine interface for vllm-mlx inference. 
""" +import asyncio +import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +logger = logging.getLogger(__name__) + @dataclass class GenerationOutput: @@ -27,6 +32,61 @@ class GenerationOutput: finished: bool = True +@contextmanager +def suspend_cancellation(): + """Temporarily clear task cancellation so cleanup can finish deterministically.""" + task = asyncio.current_task() + if task is None: + yield + return + + cancelling = getattr(task, "cancelling", None) + uncancel = getattr(task, "uncancel", None) + if cancelling is None or uncancel is None: + yield + return + + pending_cancels = cancelling() + for _ in range(pending_cancels): + uncancel() + try: + yield + finally: + for _ in range(pending_cancels): + task.cancel() + + +async def run_blocking_startup_work(work: Callable[[], Any]) -> None: + """Run blocking startup work off-loop without leaking cancellation races.""" + task = asyncio.create_task(asyncio.to_thread(work)) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + with suspend_cancellation(): + while not task.done(): + try: + await asyncio.shield(task) + except asyncio.CancelledError: + continue + except Exception: + break + raise + + +async def cleanup_startup_cancellation(cleanup: Callable[[], Awaitable[None]]) -> None: + """Run startup cleanup without letting cleanup failures replace cancellation.""" + with suspend_cancellation(): + try: + await cleanup() + except BaseException as exc: + if isinstance(exc, (KeyboardInterrupt, SystemExit)): + raise + logger.error( + "Engine startup cleanup failed while preserving cancellation", + exc_info=(type(exc), exc, exc.__traceback__), + ) + + class BaseEngine(ABC): """ Abstract base class for inference engines. @@ -67,6 +127,15 @@ def preserve_native_tool_format(self) -> bool: def preserve_native_tool_format(self, value: bool) -> None: self._preserve_native_tool_format = value + def prepare_for_start(self) -> None: + """Run blocking startup work before async engine start. + + Engines can override this to perform heavyweight synchronous model + loads off the serving event loop. The default implementation is a + no-op so lightweight engines do not need extra plumbing. + """ + return None + @abstractmethod async def start(self) -> None: """Start the engine (load model if not loaded).""" @@ -196,3 +265,7 @@ def get_stats(self) -> dict[str, Any]: def get_cache_stats(self) -> dict[str, Any] | None: """Get cache statistics. Override in subclasses.""" return None + + def clear_runtime_caches(self) -> dict[str, Any] | None: + """Clear engine-managed runtime caches. Override in subclasses.""" + return None diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 809b4b1d1..a58b350b1 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -11,6 +11,7 @@ LLM engine), so text-only requests must also be routed through it. 
""" +import asyncio import logging import time from collections.abc import AsyncIterator @@ -18,7 +19,12 @@ from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_output_text, extract_multimodal_content, is_mllm_model -from .base import BaseEngine, GenerationOutput +from .base import ( + BaseEngine, + GenerationOutput, + cleanup_startup_cancellation, + run_blocking_startup_work, +) logger = logging.getLogger(__name__) @@ -134,7 +140,7 @@ class BatchedEngine(BaseEngine): def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, @@ -185,38 +191,99 @@ def tokenizer(self) -> Any: return getattr(self._processor, "tokenizer", self._processor) return self._tokenizer + def prepare_for_start(self) -> None: + """Load heavyweight model state off the serving event loop.""" + if self._model is not None: + return + + if self._is_mllm: + self._prepare_mllm_model() + else: + self._prepare_llm_model() + async def start(self) -> None: """Start the engine (load model if not loaded).""" if self._loaded: return - if self._is_mllm: - await self._start_mllm() - else: - await self._start_llm() + try: + if self._model is None: + if self._uses_default_prepare_for_start(): + # Load inline on the event-loop thread so mlx-lm's + # generation_stream (created at module import on this + # thread) and model weights are owned by the same thread + # that drives scheduler.step (issue #407). + self.prepare_for_start() + else: + # Test doubles and custom overrides may block; run them via + # the shared cancellation-safe thread helper. + await run_blocking_startup_work(self.prepare_for_start) + + if self._is_mllm: + await self._start_mllm() + else: + await self._start_llm() - self._loaded = True - logger.info(f"BatchedEngine loaded: {self._model_name} (mllm={self._is_mllm})") + self._loaded = True + logger.info( + f"BatchedEngine loaded: {self._model_name} (mllm={self._is_mllm})" + ) + except asyncio.CancelledError: + await cleanup_startup_cancellation(self.stop) + raise - async def _start_mllm(self) -> None: - """Start the MLLM engine with MLLMScheduler (continuous batching).""" - from ..mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig + def _uses_default_prepare_for_start(self) -> bool: + """Return True when prepare_for_start is the class implementation.""" + method = getattr(self.prepare_for_start, "__func__", None) + return method is BatchedEngine.prepare_for_start + + def _prepare_mllm_model(self) -> None: + """Load the MLLM model before scheduler startup.""" from ..models.mllm import MLXMultimodalLM - # Load the MLLM model self._mllm_instance = MLXMultimodalLM( self._model_name, trust_remote_code=self._trust_remote_code, ) self._mllm_instance.load() - self._model = self._mllm_instance.model self._processor = self._mllm_instance.processor + # Set Metal memory limits (same as LLM path) + try: + import mlx.core as mx + + if mx.metal.is_available(): + device_info = mx.device_info() + max_recommended = device_info.get( + "max_recommended_working_set_size", + device_info.get("memory_size", 0), + ) + if max_recommended > 0: + soft_limit = int(max_recommended * self._gpu_memory_utilization) + mx.set_memory_limit(soft_limit) + mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB + pct = self._gpu_memory_utilization * 100 + logger.info( + f"Metal memory limits set: " + f"allocation_limit={soft_limit / 1e9:.1f}GB " + f"({pct:.0f}% of {max_recommended / 
1e9:.1f}GB), " + f"cache_limit=32GB" + ) + except Exception as e: + logger.warning(f"Failed to set Metal memory limits: {e}") + # Inject MTP support if enabled if self._scheduler_config and self._scheduler_config.enable_mtp: self._inject_mtp_mllm() + async def _start_mllm(self) -> None: + """Start the MLLM engine with MLLMScheduler (continuous batching).""" + from ..mllm_scheduler import MLLMScheduler, MLLMSchedulerConfig + + if self._model is None or self._processor is None: + self._prepare_mllm_model() + # Create MLLM scheduler config with batch generator support if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"): max_num_seqs = self._scheduler_config.max_num_seqs @@ -240,11 +307,13 @@ async def _start_mllm(self) -> None: self._scheduler_config, "kv_cache_quantization_group_size", 64 ) - # Forward MLLM prefill-step override only when explicitly configured. - # This keeps default behavior unchanged for MLLM (1024) unless set. prefill_step_size = getattr( self._scheduler_config, "mllm_prefill_step_size", None ) + if prefill_step_size is None: + prefill_step_size = getattr( + self._scheduler_config, "prefill_step_size", None + ) mllm_extra = {} if prefill_step_size is not None: mllm_extra["prefill_step_size"] = prefill_step_size @@ -326,12 +395,13 @@ def _inject_mtp_mllm(self) -> None: else: logger.info(f"[MTP-MLLM] MTP not supported for model_type={model_type}") - async def _start_llm(self) -> None: - """Start the LLM engine with AsyncEngineCore.""" - from ..engine_core import AsyncEngineCore, EngineConfig - from ..scheduler import SchedulerConfig + def _prepare_llm_model(self) -> None: + """Load the LLM model/tokenizer before engine loop startup.""" from ..utils.tokenizer import load_model_with_fallback + if self._model is not None and self._tokenizer is not None: + return + # Build tokenizer config tokenizer_config = {"trust_remote_code": self._trust_remote_code} @@ -357,8 +427,10 @@ async def _start_llm(self) -> None: "See warnings above for details." ) - # Set Metal memory limits to make allocation failures graceful - # instead of fatal Metal command buffer errors (SIGABRT) + self._configure_metal_memory_limits() + + def _configure_metal_memory_limits(self) -> None: + """Make MLX allocation failures graceful during startup.""" try: import mlx.core as mx @@ -382,6 +454,26 @@ async def _start_llm(self) -> None: except Exception as e: logger.warning(f"Failed to set Metal memory limits: {e}") + async def _start_llm(self) -> None: + """Start the LLM engine with AsyncEngineCore.""" + from ..engine_core import AsyncEngineCore, EngineConfig + from ..scheduler import SchedulerConfig + + if self._model is None or self._tokenizer is None: + self._prepare_llm_model() + + # Validate MTP support if enabled + if self._scheduler_config and self._scheduler_config.enable_mtp: + from ..patches.qwen3_next_mtp import validate_mtp_support + + if validate_mtp_support(self._model): + logger.info("[MTP] Model validated for MTP speculative decoding") + else: + logger.warning( + "[MTP] MTP validation failed — --enable-mtp will be ignored. " + "See warnings above for details." + ) + # Create engine config scheduler_config = self._scheduler_config or SchedulerConfig() engine_config = EngineConfig( @@ -423,6 +515,7 @@ def _apply_chat_template( messages: list[dict[str, Any]], tools: list[dict] | None = None, num_images: int = 0, + chat_template_kwargs: dict[str, Any] | None = None, enable_thinking: bool | None = None, ) -> str: """Apply chat template to messages. 
@@ -460,20 +553,42 @@ def _apply_chat_template( "add_generation_prompt": True, "enable_thinking": enable_thinking, } - if tools: + if chat_template_kwargs: + template_kwargs.update(chat_template_kwargs) + if tools and "tools" not in template_kwargs: template_kwargs["tools"] = tools + tokenizer_applicator = None + tokenizer = self.tokenizer + if template_applicator is not tokenizer and hasattr( + tokenizer, "apply_chat_template" + ): + tokenizer_applicator = tokenizer + try: return template_applicator.apply_chat_template( messages, **template_kwargs ) + except ValueError as e: + # Some HF processors define apply_chat_template but do not carry + # a template (e.g. Gemma-3 processor). Retry on tokenizer. + if ( + tokenizer_applicator is not None + and "does not have a chat template" in str(e) + ): + return tokenizer_applicator.apply_chat_template( + messages, **template_kwargs + ) + raise except TypeError as e: - # Some templates don't accept 'tools' or 'enable_thinking'; - # retry without them. + # Some templates don't accept extra kwargs; retry without them. logger.debug(f"Chat template TypeError, retrying without extras: {e}") - for key in ["tools", "enable_thinking"]: - if key in template_kwargs: - del template_kwargs[key] + for key in [ + "tools", + "enable_thinking", + *(chat_template_kwargs or {}).keys(), + ]: + template_kwargs.pop(key, None) return template_applicator.apply_chat_template( messages, **template_kwargs ) @@ -509,7 +624,7 @@ def _prepare_mllm_messages( for part in content: if isinstance(part, dict) and part.get("type") == "image_url": new_content.append({"type": "image"}) - elif isinstance(part, (dict, str)): + elif isinstance(part, (dict | str)): new_content.append(part) # skip non-dict/non-str parts to avoid passing unexpected types prepared.append({**msg, "content": new_content}) @@ -562,6 +677,7 @@ async def generate( min_p=kwargs.pop("min_p", 0.0), presence_penalty=kwargs.pop("presence_penalty", 0.0), repetition_penalty=kwargs.pop("repetition_penalty", 1.0), + logits_processors=kwargs.pop("logits_processors", None), ) return GenerationOutput( @@ -584,6 +700,7 @@ async def generate( presence_penalty=kwargs.pop("presence_penalty", 0.0), repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], + logits_processors=kwargs.pop("logits_processors", None), ) output = await self._engine.generate( @@ -644,6 +761,7 @@ async def stream_generate( min_p=kwargs.pop("min_p", 0.0), presence_penalty=kwargs.pop("presence_penalty", 0.0), repetition_penalty=kwargs.pop("repetition_penalty", 1.0), + logits_processors=kwargs.pop("logits_processors", None), ) async for output in self._mllm_scheduler.stream_outputs(request_id): @@ -669,6 +787,7 @@ async def stream_generate( presence_penalty=kwargs.pop("presence_penalty", 0.0), repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], + logits_processors=kwargs.pop("logits_processors", None), ) prefix_boundary = kwargs.pop("prefix_boundary", 0) @@ -732,6 +851,7 @@ async def chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) # Per-request enable_thinking override enable_thinking = kwargs.pop("enable_thinking", None) @@ -741,6 +861,7 @@ async def chat( messages, template_tools, num_images=len(all_images), + chat_template_kwargs=chat_template_kwargs, enable_thinking=enable_thinking, ) @@ -755,7 +876,10 @@ async def chat( ) def _compute_prefix_boundary( - self, messages: list[dict[str, 
Any]], tools: list[dict] | None = None + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, ) -> int: """Compute token count for the shared prefix across message variations. @@ -777,7 +901,11 @@ def _compute_prefix_boundary( template_tools = convert_tools_for_template(tools) if tools else None # Tokenize the real prompt - real_prompt = self._apply_chat_template(messages, template_tools) + real_prompt = self._apply_chat_template( + messages, + template_tools, + chat_template_kwargs=chat_template_kwargs, + ) # Build a dummy variant with different last user content dummy_messages = list(messages) @@ -785,7 +913,11 @@ def _compute_prefix_boundary( **messages[last_user_idx], "content": "XXXXXXXXXX", } - dummy_prompt = self._apply_chat_template(dummy_messages, template_tools) + dummy_prompt = self._apply_chat_template( + dummy_messages, + template_tools, + chat_template_kwargs=chat_template_kwargs, + ) tokenizer = self.tokenizer if hasattr(tokenizer, "tokenizer"): @@ -847,6 +979,7 @@ async def stream_chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) # Per-request enable_thinking override enable_thinking = kwargs.pop("enable_thinking", None) @@ -856,11 +989,16 @@ async def stream_chat( messages, template_tools, num_images=len(all_images), + chat_template_kwargs=chat_template_kwargs, enable_thinking=enable_thinking, ) # Compute prefix boundary for cache - prefix_boundary = self._compute_prefix_boundary(messages, tools) + prefix_boundary = self._compute_prefix_boundary( + messages, + tools, + chat_template_kwargs=chat_template_kwargs, + ) if prefix_boundary > 0: kwargs["prefix_boundary"] = prefix_boundary @@ -918,11 +1056,22 @@ def get_stats(self) -> dict[str, Any]: def get_cache_stats(self) -> dict[str, Any] | None: """Get cache statistics.""" if self._mllm_scheduler and self._mllm_scheduler.batch_generator: - return self._mllm_scheduler.batch_generator.get_vision_cache_stats() + return { + "prefix_cache": self._mllm_scheduler.batch_generator.get_prefix_cache_stats(), + "vision_embedding_cache": self._mllm_scheduler.batch_generator.get_vision_cache_stats(), + } elif self._engine: return self._engine.get_cache_stats() return None + def clear_runtime_caches(self) -> dict[str, Any] | None: + """Clear engine-managed runtime caches.""" + if self._mllm_scheduler is not None: + return self._mllm_scheduler.clear_runtime_caches() + if self._engine is not None: + return self._engine.clear_runtime_caches() + return None + def save_cache_to_disk(self, cache_dir: str) -> bool: """Save prefix cache to disk for persistence across restarts.""" if self._mllm_scheduler and self._mllm_scheduler.batch_generator: @@ -935,10 +1084,22 @@ def save_cache_to_disk(self, cache_dir: str) -> bool: def load_cache_from_disk(self, cache_dir: str) -> int: """Load prefix cache from disk. Returns number of entries loaded.""" - if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + if self._mllm_scheduler: + self._mllm_scheduler._ensure_batch_generator() pc = self._mllm_scheduler.batch_generator.prefix_cache if pc is not None: return pc.load_from_disk(cache_dir) if self._engine: return self._engine.load_cache_from_disk(cache_dir) return 0 + + def clear_prefix_cache(self) -> None: + """Clear the in-memory prefix cache. 
Used by bench-serve for clean + cold-start measurements between configurations.""" + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + pc = self._mllm_scheduler.batch_generator.prefix_cache + if pc is not None and hasattr(pc, "clear"): + pc.clear() + return + if self._engine and hasattr(self._engine, "clear_prefix_cache"): + self._engine.clear_prefix_cache() diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 2235dda14..18b9c5a9b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -8,38 +8,99 @@ import asyncio import logging +import os +import threading import time from collections.abc import AsyncIterator from typing import Any +import mlx.core as mx + from ..api.tool_calling import convert_tools_for_template -from ..api.utils import clean_output_text, is_mllm_model -from .base import BaseEngine, GenerationOutput +from ..api.utils import clean_output_text, has_media_content, is_mllm_model +from .base import ( + BaseEngine, + GenerationOutput, + cleanup_startup_cancellation, + run_blocking_startup_work, +) +from ..mlx_streams import bind_generation_streams logger = logging.getLogger(__name__) -_MEDIA_TYPES = frozenset( - { - "image_url", - "video_url", - "audio_url", - "image", - "video", - "audio", - } -) +def _bind_worker_generation_streams() -> None: + """Rebind mlx generation streams inside the current worker thread.""" + bind_generation_streams() -def _has_media_content(messages: list) -> bool: - """Check if any message contains media content (images, video, audio).""" - for msg in messages: - content = msg.get("content") - if isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: - return True - return False +def _seed_logits_processors( + seed_tokens: mx.array | None, + processors: list[Any] | None, +) -> list[Any] | None: + """Wrap logits processors so continuation decode sees the full prompt.""" + if not processors: + return None + if seed_tokens is None or seed_tokens.size == 0: + return list(processors) + + def _wrap(processor): + def _seeded(tokens, logits): + merged = seed_tokens + if tokens is not None: + if not isinstance(tokens, mx.array): + tokens_arr = mx.array(tokens, dtype=mx.uint32) + else: + tokens_arr = tokens + if tokens_arr.size > 0: + merged = mx.concatenate([seed_tokens, tokens_arr]) + return processor(merged, logits) + + return _seeded + + return [_wrap(processor) for processor in processors] + + +def _sample_with_processors( + tokens: mx.array | None, + logits: mx.array, + sampler: Any, + logits_processors: list[Any] | None, +) -> tuple[mx.array, mx.array]: + """Sample a token while honoring any active logits processors.""" + if logits_processors: + is_1d = logits.ndim == 1 + if is_1d: + logits = logits[None] + for processor in logits_processors: + logits = processor(tokens, logits) + if is_1d: + logits = logits.squeeze(0) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + tok = sampler(logprobs) + return tok, logprobs + + +def _processors_can_retire(processors: list[Any] | None) -> bool: + """True when any processor advertises a retire-to-content transition.""" + if os.getenv("VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME") != "1": + return False + return bool(processors) and any( + isinstance(getattr(p, "is_retired", None), bool) for p in processors + ) + + +def _processors_retired(processors: list[Any] | None) -> bool: + """True when any retire-capable processor has entered its retired state.""" + if 
os.getenv("VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME") != "1": + return False + return bool(processors) and any( + getattr(p, "is_retired", False) is True for p in processors + ) + + +class _SpecPrefillCancelled(Exception): + """Cooperative cancellation sentinel for blocking SpecPrefill workers.""" class SimpleEngine(BaseEngine): @@ -53,10 +114,11 @@ class SimpleEngine(BaseEngine): def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, enable_cache: bool = True, force_mllm: bool = False, mtp: bool = False, + mtp_num_draft_tokens: int = 1, prefill_step_size: int = 2048, specprefill_enabled: bool = False, specprefill_threshold: int = 8192, @@ -72,6 +134,7 @@ def __init__( enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected mtp: Enable native MTP speculative decoding (model must have MTP head) + mtp_num_draft_tokens: Draft tokens per speculative MTP step prefill_step_size: Chunk size for prompt prefill processing (default: 2048) specprefill_enabled: Enable SpecPrefill (attention-based sparse prefill) specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill @@ -84,6 +147,7 @@ def __init__( self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) self._mtp = mtp + self._mtp_num_draft_tokens = mtp_num_draft_tokens self._prefill_step_size = prefill_step_size # SpecPrefill config @@ -129,9 +193,9 @@ def tokenizer(self) -> Any: return getattr(self._model, "processor", None) return self._model.tokenizer - async def start(self) -> None: - """Start the engine (load model if not loaded).""" - if self._loaded: + def prepare_for_start(self) -> None: + """Load the backing model off the serving event loop.""" + if self._model is not None: return if self._is_mllm: @@ -149,72 +213,101 @@ async def start(self) -> None: self._model_name, trust_remote_code=self._trust_remote_code, mtp=self._mtp, + mtp_num_draft_tokens=self._mtp_num_draft_tokens, ) self._model.load() - self._loaded = True - # Build parallel mlx_lm TextModel for text-only MTP routing - if self._is_mllm and self._mtp: - try: - from ..text_model_from_vlm import build_text_model + async def start(self) -> None: + """Start the engine (load model if not loaded).""" + if self._loaded: + return + try: + if self._model is None: + await run_blocking_startup_work(self.prepare_for_start) + self._loaded = True + + if self._mtp and self._mtp_num_draft_tokens != 1: + logger.warning( + "Native mlx_lm MTP currently ignores num_draft_tokens=%d; " + "effective speculative draft depth remains 1", + self._mtp_num_draft_tokens, + ) - self._text_model = build_text_model(self._model.model, self._model_name) + # Build parallel mlx_lm TextModel for text-only routing. + # Even when MTP is disabled, text-only requests should not be trapped + # on the slower mlx_vlm multimodal path. 
+ if self._is_mllm: + try: + from ..text_model_from_vlm import build_text_model - if ( - self._text_model is not None - and hasattr(self._text_model, "mtp") - and self._text_model.mtp is not None - ): - self._text_tokenizer = self._model.get_tokenizer() + self._text_model = build_text_model( + self._model.model, self._model_name + ) + + if self._text_model is not None: + self._text_tokenizer = self._model.get_tokenizer() - # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) - if "qwen3" in self._model_name.lower(): - self._text_tokenizer.eos_token = "<|im_end|>" - self._text_tokenizer.eos_token_id = ( - self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + has_mtp = ( + hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ) + logger.info( + "MLLM text routing: text-only -> mlx_lm TextModel " + "(MTP=%s), media -> mlx_vlm", + has_mtp and self._mtp, ) + else: + self._text_model = None + self._text_tokenizer = None - logger.info( - "MLLM+MTP routing: text-only → mlx_lm TextModel (MTP=True), " - "media → mlx_vlm" - ) - else: - logger.warning( - "TextModel built but no MTP — text-only requests won't use MTP" - ) + except Exception as e: + logger.error("MLLM text routing setup failed: %s", e) self._text_model = None + self._text_tokenizer = None - except Exception as e: - logger.error("MLLM+MTP routing setup failed: %s", e) - self._text_model = None - self._text_tokenizer = None - - # Load SpecPrefill draft model (small model for importance scoring) - if self._specprefill_enabled and self._specprefill_draft_model_path: - try: - from mlx_lm import load as mlx_lm_load + # Load SpecPrefill draft model (small model for importance scoring) + if self._specprefill_enabled and self._specprefill_draft_model_path: + try: + from mlx_lm import load as mlx_lm_load - self._draft_model, _ = mlx_lm_load(self._specprefill_draft_model_path) - logger.info( - "SpecPrefill: draft model loaded (%s), threshold=%d, keep=%.0f%%", - self._specprefill_draft_model_path, - self._specprefill_threshold, - self._specprefill_keep_pct * 100, + self._draft_model, _ = mlx_lm_load( + self._specprefill_draft_model_path + ) + logger.info( + "SpecPrefill: draft model loaded (%s), threshold=%d, keep=%.0f%%", + self._specprefill_draft_model_path, + self._specprefill_threshold, + self._specprefill_keep_pct * 100, + ) + except Exception as e: + logger.error("SpecPrefill: draft model load failed: %s", e) + self._draft_model = None + + mtp_info = "" + if self._mtp: + mtp_info = ( + f", MTP={self._mtp}(configured={self._mtp_num_draft_tokens}, " + "effective=1)" ) - except Exception as e: - logger.error("SpecPrefill: draft model load failed: %s", e) - self._draft_model = None - - mtp_info = f", MTP={self._mtp}" if self._mtp else "" - routing = ", routing=per-request" if self._text_model is not None else "" - specprefill_info = ( - ", SpecPrefill=active" if self._draft_model is not None else "" - ) - logger.info( - f"SimpleEngine loaded: {self._model_name} " - f"(MLLM={self._is_mllm}{mtp_info}{routing}{specprefill_info})" - ) + routing = ", routing=per-request" if self._text_model is not None else "" + specprefill_info = ( + ", SpecPrefill=active" if self._draft_model is not None else "" + ) + logger.info( + f"SimpleEngine loaded: {self._model_name} " + 
f"(MLLM={self._is_mllm}{mtp_info}{routing}{specprefill_info})" + ) + except asyncio.CancelledError: + await cleanup_startup_cancellation(self.stop) + raise async def stop(self) -> None: """Stop the engine and cleanup resources.""" @@ -228,7 +321,7 @@ async def stop(self) -> None: self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") - async def _run_blocking_serialized(self, func, /, *args, **kwargs): + async def _run_blocking_serialized(self, func, /, *args, on_cancel=None, **kwargs): """Run a blocking MLX operation under the generation lock. Cancellation must not release the async lock before the worker thread @@ -236,10 +329,23 @@ async def _run_blocking_serialized(self, func, /, *args, **kwargs): corrupt the command-buffer state. """ async with self._generation_lock: - task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs)) + + def run_bound(): + _bind_worker_generation_streams() + return func(*args, **kwargs) + + task = asyncio.create_task(asyncio.to_thread(run_bound)) try: return await asyncio.shield(task) except asyncio.CancelledError: + if on_cancel is not None: + try: + on_cancel() + except Exception: + logger.debug( + "Blocking worker cancellation callback failed", + exc_info=True, + ) try: await task except BaseException: @@ -470,11 +576,9 @@ async def chat( if not self._loaded: await self.start() - # mlx-lm non-streaming chat with tools can stall indefinitely on some - # local models, while the streaming path completes normally. Reuse the - # streaming implementation and aggregate its final state so both chat - # APIs share the same tool-capable execution path. - if tools and not self._is_mllm: + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + + async def aggregate_stream_chat() -> GenerationOutput: final_output = GenerationOutput(text="") async for output in self.stream_chat( messages=messages, @@ -484,6 +588,7 @@ async def chat( tools=tools, images=images, videos=videos, + chat_template_kwargs=chat_template_kwargs, **kwargs, ): final_output = output @@ -496,10 +601,30 @@ async def chat( finish_reason=final_output.finish_reason, ) + # mlx-lm non-streaming chat with tools can stall indefinitely on some + # local models, while the streaming path completes normally. Reuse the + # streaming implementation and aggregate its final state so both chat + # APIs share the same tool-capable execution path. + if tools and not self._is_mllm: + return await aggregate_stream_chat() + + # Text-only requests on MLLM models should use the TextModel route even + # for non-streaming chat. Aggregating the streaming path keeps one + # execution seam for text-only requests and avoids thread-local mlx_vlm + # stream assumptions inside to_thread(). 
+ if ( + self._is_mllm + and self._text_model is not None + and not has_media_content(messages) + ): + return await aggregate_stream_chat() + # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None if self._is_mllm: + if chat_template_kwargs: + kwargs["chat_template_kwargs"] = chat_template_kwargs output = await self._run_blocking_serialized( self._model.chat, messages=messages, @@ -523,6 +648,7 @@ async def chat( temperature=temperature, top_p=top_p, tools=template_tools, + chat_template_kwargs=chat_template_kwargs, **kwargs, ) text = clean_output_text(output.text) @@ -575,16 +701,23 @@ async def stream_chat( if not self._loaded: await self.start() + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None - # Per-request routing: text-only through mlx_lm with MTP + # Per-request routing: text-only through mlx_lm TextModel if ( self._is_mllm and self._text_model is not None - and not _has_media_content(messages) + and not has_media_content(messages) ): - logger.info("Text-only request → LLM path (MTP=True)") + has_mtp = ( + hasattr(self._text_model, "mtp") and self._text_model.mtp is not None + ) + logger.info("Text-only request → LLM path (MTP=%s)", has_mtp and self._mtp) + if chat_template_kwargs: + kwargs["chat_template_kwargs"] = chat_template_kwargs async for chunk in self._stream_generate_text( messages, max_tokens, @@ -608,13 +741,16 @@ async def stream_chat( # Run stream_chat in thread pool since it's synchronous def run_stream(): + local_kwargs = dict(kwargs) + if chat_template_kwargs: + local_kwargs["chat_template_kwargs"] = chat_template_kwargs return list( self._model.stream_chat( messages=messages, max_tokens=max_tokens, temperature=temperature, tools=template_tools, - **kwargs, + **local_kwargs, ) ) @@ -652,6 +788,8 @@ def run_stream(): "add_generation_prompt": True, "enable_thinking": enable_thinking, } + if chat_template_kwargs: + template_kwargs.update(chat_template_kwargs) if template_tools: template_kwargs["tools"] = template_tools @@ -659,7 +797,7 @@ def run_stream(): prompt = tokenizer.apply_chat_template(messages, **template_kwargs) except TypeError: # Some templates don't support all kwargs - for key in ["tools", "enable_thinking"]: + for key in ["tools", "enable_thinking", *chat_template_kwargs.keys()]: if key in template_kwargs: del template_kwargs[key] prompt = tokenizer.apply_chat_template(messages, **template_kwargs) @@ -694,17 +832,25 @@ async def _stream_generate_specprefill( model, then generates autoregressively. Falls back to normal generation on any error. 
""" - import mlx.core as mx - from mlx_lm.models.cache import make_prompt_cache - from mlx_lm.sample_utils import make_sampler + from threading import Event model = self._model.model tokenizer = self._model.tokenizer n_tokens = len(tokens) + cancel_requested = Event() + + def _request_cancel() -> None: + cancel_requested.set() + + def _cancel_check() -> None: + if cancel_requested.is_set(): + raise _SpecPrefillCancelled() def _run_all(): try: return _run_specprefill() + except _SpecPrefillCancelled: + raise except Exception as e: logger.error("SpecPrefill failed, falling back to normal path: %s", e) return _run_normal() @@ -714,6 +860,10 @@ def _run_specprefill(): import time from types import SimpleNamespace + import mlx.core as mx + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + from ..specprefill import ( cleanup_rope, score_tokens, @@ -730,10 +880,12 @@ def _run_specprefill(): self._draft_model, tokens, prefill_step_size=self._prefill_step_size, + cancel_check=_cancel_check, ) t_score = time.monotonic() - t0 # Phase 2: Select important chunks + _cancel_check() effective_keep = specprefill_keep_pct or self._specprefill_keep_pct selected = select_chunks(importance, keep_pct=effective_keep) n_selected = selected.shape[0] @@ -746,6 +898,7 @@ def _run_specprefill(): selected, cache, step_size=self._prefill_step_size, + cancel_check=_cancel_check, ) t_prefill = time.monotonic() - t0 @@ -760,38 +913,37 @@ def _run_specprefill(): t_prefill, ) - # Phase 4: Generate (simple autoregressive, no MTP) + # Phase 4: Generate via engine's standard pipelined path sampler = make_sampler(temp=temperature, top_p=top_p) + _cancel_check() + first_token_id = sampler(logits[:, -1, :]).item() + first_text = tokenizer.decode([first_token_id]) eos_id = tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) - results = [] - generated_ids = [] - prev_decoded = "" - - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) - - decoded = tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded - - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + results = [ + SimpleNamespace( + text=first_text, + finish_reason="stop" if first_token_id == eos_id else None, ) + ] - if is_eos: - break - - logits = model(y.reshape(1, -1), cache=cache) - y = sampler(logits[:, -1, :]) - mx.eval(y) + if first_token_id != eos_id: + for chunk in self._model.stream_generate( + prompt=mx.array([first_token_id]), + max_tokens=max_tokens - 1, + temperature=temperature, + top_p=top_p, + stop=stop, + prompt_cache=cache, + ): + _cancel_check() + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + results.append( + SimpleNamespace( + text=new_text, + finish_reason=getattr(chunk, "finish_reason", None), + ) + ) return results @@ -811,6 +963,7 @@ def _run_normal(): stop=stop, **kwargs, ): + _cancel_check() new_text = chunk.text if hasattr(chunk, "text") else str(chunk) results.append( SimpleNamespace( @@ -820,7 +973,9 @@ def _run_normal(): ) return results - all_resps = await self._run_blocking_serialized(_run_all) + all_resps = await self._run_blocking_serialized( + _run_all, on_cancel=_request_cancel + ) # Yield results as GenerationOutput accumulated_text = "" @@ -865,9 +1020,9 @@ async def _stream_generate_text( tools: list | None = None, **kwargs, ) -> AsyncIterator[GenerationOutput]: - """Text-only generation via mlx_lm 
TextModel with MTP. + """Text-only generation via mlx_lm TextModel. - Used when MLLM+MTP routing is active and the request has no media. + Used when text-only MLLM routing is active and the request has no media. Runs the full generation in a single thread to maintain Metal safety. System prompt KV caching: on the first request, prefills system tokens @@ -879,12 +1034,21 @@ async def _stream_generate_text( import mlx.core as mx from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models import cache as cache_module from mlx_lm.models.cache import make_prompt_cache - from mlx_lm.sample_utils import make_sampler + from mlx_lm.sample_utils import make_logits_processors, make_sampler # Per-request specprefill overrides (from extra_body) specprefill_override = kwargs.pop("specprefill", None) specprefill_keep_pct = kwargs.pop("specprefill_keep_pct", None) + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + top_k = kwargs.pop("top_k", 0) + min_p = kwargs.pop("min_p", 0.0) + presence_penalty = kwargs.pop("presence_penalty", 0.0) + repetition_penalty = kwargs.pop("repetition_penalty", 1.0) + stop = kwargs.pop("stop", None) + external_logits_processors = kwargs.pop("logits_processors", None) + abort_event = threading.Event() # Per-request enable_thinking override; fall back to env var / default True. enable_thinking = kwargs.pop("enable_thinking", None) @@ -898,6 +1062,7 @@ async def _stream_generate_text( "add_generation_prompt": True, "enable_thinking": enable_thinking, } + template_kwargs.update(chat_template_kwargs) if tools: template_kwargs["tools"] = tools @@ -913,8 +1078,20 @@ async def _stream_generate_text( messages, **template_kwargs ) - # Build sampler - sampler = make_sampler(temp=temperature, top_p=top_p) + sampler = make_sampler( + temp=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + ) + penalty_processors = make_logits_processors( + repetition_penalty=( + repetition_penalty if repetition_penalty != 1.0 else None + ), + presence_penalty=presence_penalty if presence_penalty != 0.0 else None, + ) + all_processors = (external_logits_processors or []) + (penalty_processors or []) + custom_logits_active = bool(all_processors) max_tokens = max_tokens or 4096 # --- System prompt KV caching --- @@ -979,12 +1156,29 @@ async def _stream_generate_text( and system_token_count == self._system_kv_token_count ): # Cache HIT — restore KV state into fresh backbone cache - backbone_cache = make_prompt_cache(self._text_model) - for i, saved_state in enumerate(self._system_kv_snapshot): - backbone_cache[i].state = saved_state - - prompt_to_send = mx.array(suffix_tokens) + def make_cache_with_snapshot( + text_model, + system_kv_snapshot, + ): + import mlx.core as mx + from mlx_lm.models.cache import make_prompt_cache + + backbone_cache = make_prompt_cache(text_model) + for i, saved_state in enumerate(system_kv_snapshot): + backbone_cache[i].state = saved_state + + prompt_to_send = mx.array(suffix_tokens) + return backbone_cache, prompt_to_send + + backbone_cache, prompt_to_send = ( + await self._run_blocking_serialized( + make_cache_with_snapshot, + self._text_model, + self._system_kv_snapshot, + ) + ) cache_hit = True + logger.info( "System KV cache HIT: reusing %d cached tokens, " "prefilling %d new tokens (hash=%s)", @@ -1063,11 +1257,76 @@ async def _stream_generate_text( ) use_specprefill = False + loop = asyncio.get_running_loop() + response_queue: asyncio.Queue[tuple[str, Any]] = asyncio.Queue() + + def _emit_response(resp: Any) -> None: + if 
abort_event.is_set(): + return + loop.call_soon_threadsafe(response_queue.put_nowait, ("resp", resp)) + + def _emit_done() -> None: + loop.call_soon_threadsafe(response_queue.put_nowait, ("done", None)) + + def _emit_error(exc: BaseException) -> None: + loop.call_soon_threadsafe(response_queue.put_nowait, ("error", exc)) + + def _seed_from_last_response(prompt_cache, last_resp): + last_tok = getattr(last_resp, "token", None) + if last_tok is not None: + cache_module.trim_prompt_cache(prompt_cache, 1) + return mx.array([last_tok], dtype=mx.uint32) + return mx.array( + self._text_tokenizer.encode(getattr(last_resp, "text", "")), + dtype=mx.uint32, + ) + + def _resume_after_processor_retirement( + model, + prompt_cache, + prompt, + remaining_tokens: int, + ) -> None: + resume_kwargs = dict( + max_tokens=remaining_tokens, + sampler=sampler, + prefill_step_size=self._prefill_step_size, + prompt_cache=prompt_cache, + ) + if hasattr(model, "make_mtp_cache") and model.mtp is not None: + # Resume speculative decode from the retained backbone cache with + # a fresh MTP cache so stale speculative state cannot survive the + # processor-to-content handoff. + resume_kwargs["prompt_cache"] = prompt_cache + model.make_mtp_cache() + resume_kwargs["mtp"] = True + resume_kwargs["num_draft_tokens"] = self._mtp_num_draft_tokens + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt, + **resume_kwargs, + ): + if abort_event.is_set(): + logger.info("Text route: abort requested; stopping resume decode") + break + _emit_response(resp) + # Run all Metal ops in a single serialized thread. def _run_all(): nonlocal backbone_cache, prompt_to_send model = self._text_model + can_retire_processors = _processors_can_retire(all_processors) + use_mtp = ( + self._mtp + and not custom_logits_active + and hasattr(model, "mtp") + and model.mtp is not None + ) + if self._mtp and custom_logits_active: + logger.info( + "Text route: disabling MTP for request-local logits processors" + ) # Cache MISS with valid prefix: prefill system tokens and snapshot if ( @@ -1111,7 +1370,8 @@ def _run_all(): # --- SpecPrefill path (with fallback to normal on failure) --- if use_specprefill: try: - return _run_specprefill(model, backbone_cache) + _run_specprefill(model, backbone_cache, use_mtp) + return except Exception as e: logger.error( "SpecPrefill failed, falling back to normal MTP path: %s", @@ -1121,39 +1381,89 @@ def _run_all(): backbone_cache = None prompt_to_send = full_prompt - # --- Normal path (MTP via mlx_lm stream_generate) --- + # --- Normal path (mlx_lm stream_generate) --- prompt_cache = None if backbone_cache is not None: # Add MTP cache on top of backbone - if hasattr(model, "make_mtp_cache"): + if use_mtp and hasattr(model, "make_mtp_cache"): mtp_cache = model.make_mtp_cache() prompt_cache = backbone_cache + mtp_cache else: prompt_cache = backbone_cache - results = [] gen_kwargs = dict( max_tokens=max_tokens, sampler=sampler, - mtp=True, prefill_step_size=self._prefill_step_size, ) + if all_processors: + gen_kwargs["logits_processors"] = all_processors + if use_mtp: + gen_kwargs["mtp"] = True + gen_kwargs["num_draft_tokens"] = self._mtp_num_draft_tokens if prompt_cache is not None: gen_kwargs["prompt_cache"] = prompt_cache + if can_retire_processors and not use_mtp: + shared_cache = prompt_cache + if shared_cache is None: + shared_cache = make_prompt_cache(model) + gen_kwargs["prompt_cache"] = shared_cache + + token_count = 0 + last_resp = None + retired = False + for resp in mlx_stream_generate( + 
model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + if abort_event.is_set(): + logger.info( + "Text route: abort requested; stopping decode after %d tokens", + token_count, + ) + break + _emit_response(resp) + token_count += 1 + last_resp = resp + retired = _processors_retired(all_processors) + if retired: + logger.info( + "Text route: request-local processor retired after %d tokens; " + "resuming content phase with MTP=%s", + token_count, + hasattr(model, "make_mtp_cache") and model.mtp is not None, + ) + break - for resp in mlx_stream_generate( - model, - self._text_tokenizer, - prompt=prompt_to_send, - **gen_kwargs, - ): - results.append(resp) - return results + if retired and token_count < max_tokens and last_resp is not None: + seed = _seed_from_last_response(shared_cache, last_resp) + _resume_after_processor_retirement( + model, + shared_cache, + seed, + max_tokens - token_count, + ) + else: + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + if abort_event.is_set(): + logger.info("Text route: abort requested; stopping decode") + break + _emit_response(resp) - def _run_specprefill(model, bc): - """Score tokens, sparse prefill, generate without MTP.""" + def _run_specprefill(model, bc, use_mtp): + """Score tokens, sparse prefill, then continue on the standard decode path.""" from types import SimpleNamespace + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from ..specprefill import ( cleanup_rope, score_tokens, @@ -1209,71 +1519,173 @@ def _run_specprefill(model, bc): effective_keep, ) - # Phase 4: Generate (simple autoregressive, no MTP) + # Phase 4: Sample the first token from the prefilled logits, then + # continue through mlx_lm's normal decode path so MTP and request- + # local logits processors remain active after sparse prefill. 
eos_id = self._text_tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) + seed_tokens = ( + mx.array(full_tokens_list, dtype=mx.uint32) + if full_tokens_list is not None + else None + ) + seeded_processors = _seed_logits_processors(seed_tokens, all_processors) + y, _ = _sample_with_processors( + None, + logits[:, -1, :].squeeze(0), + sampler, + seeded_processors, + ) mx.eval(y) - results = [] generated_ids = [] prev_decoded = "" - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) + tok_id = y.item() + generated_ids.append(tok_id) - # Incremental text decode - decoded = self._text_tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + _emit_response( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if abort_event.is_set(): + logger.info( + "SpecPrefill text route: abort requested after seed token" + ) + return - # Next token - logits = model(y.reshape(1, -1), cache=bc) - y = sampler(logits[:, -1, :]) - mx.eval(y) + if is_eos or max_tokens <= 1: + return - return results + prompt_cache = bc + if use_mtp and hasattr(model, "make_mtp_cache"): + prompt_cache = bc + model.make_mtp_cache() + + continuation_prompt = mx.array([tok_id], dtype=mx.uint32) + token_count = 1 + if _processors_retired(all_processors) and token_count < max_tokens: + logger.info( + "SpecPrefill text route: request-local processor retired after seed token; " + "resuming content phase with MTP=%s", + hasattr(model, "make_mtp_cache") and model.mtp is not None, + ) + _resume_after_processor_retirement( + model, + bc, + continuation_prompt, + max_tokens - token_count, + ) + return + + last_resp = None + retired = False + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=continuation_prompt, + max_tokens=max_tokens - token_count, + sampler=sampler, + prefill_step_size=self._prefill_step_size, + logits_processors=seeded_processors, + prompt_cache=prompt_cache, + mtp=use_mtp, + ): + if abort_event.is_set(): + logger.info( + "SpecPrefill text route: abort requested; stopping decode" + ) + break + _emit_response(resp) + token_count += 1 + last_resp = resp + retired = _processors_retired(all_processors) + if retired: + logger.info( + "SpecPrefill text route: request-local processor retired after %d tokens; " + "resuming content phase with MTP=%s", + token_count, + hasattr(model, "make_mtp_cache") and model.mtp is not None, + ) + break + + if retired and token_count < max_tokens and last_resp is not None: + seed = _seed_from_last_response(bc, last_resp) + _resume_after_processor_retirement( + model, + bc, + seed, + max_tokens - token_count, + ) finally: cleanup_rope(model) - all_resps = await self._run_blocking_serialized(_run_all) + async def _produce_responses() -> None: + try: + await self._run_blocking_serialized( + _run_all, + on_cancel=abort_event.set, + ) + except asyncio.CancelledError: + raise + except BaseException as exc: + _emit_error(exc) + else: + _emit_done() + + producer_task = asyncio.create_task(_produce_responses()) # Yield results as GenerationOutput accumulated_text = "" token_count = 0 finished = False - for i, resp in enumerate(all_resps): - token_count += 1 - new_text = 
resp.text if hasattr(resp, "text") else str(resp) - accumulated_text += new_text + try: + while True: + kind, payload = await response_queue.get() + if kind == "done": + break + if kind == "error": + raise payload + resp = payload - is_last = i == len(all_resps) - 1 - finished = is_last or token_count >= max_tokens + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=full_token_count or 0, - completion_tokens=token_count, - finished=finished, - finish_reason=getattr(resp, "finish_reason", None) - or ("stop" if finished else None), - ) + stop_hit = False + if stop: + stop_hit = any(stop_seq in accumulated_text for stop_seq in stop) + finished = stop_hit or token_count >= max_tokens + finish_reason = getattr(resp, "finish_reason", None) + if stop_hit: + finish_reason = "stop" + elif finish_reason is None and finished: + finish_reason = "stop" + elif finish_reason is not None: + finished = True - if finished: - break + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=full_token_count or 0, + completion_tokens=token_count, + finished=finished, + finish_reason=finish_reason, + ) + + if finished: + break + finally: + if not producer_task.done(): + abort_event.set() + await producer_task if not finished: yield GenerationOutput( @@ -1336,3 +1748,10 @@ def get_cache_stats(self) -> dict[str, Any] | None: if self._is_mllm and self._model is not None: return self._model.get_cache_stats() return None + + def clear_runtime_caches(self) -> dict[str, Any] | None: + """Clear engine-managed runtime caches.""" + if self._is_mllm and self._model is not None: + self._model.clear_cache() + return {"model_cache": True} + return None diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index ae75fd39e..83610f3db 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -15,6 +15,7 @@ import logging import time import uuid +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Any, AsyncIterator, Dict, List, Optional, Union @@ -24,10 +25,17 @@ from .scheduler import Scheduler, SchedulerConfig from .output_collector import RequestOutputCollector, RequestStreamState from .model_registry import get_registry +from .mlx_streams import bind_generation_streams logger = logging.getLogger(__name__) +def _is_stream_thread_error(error: Exception) -> bool: + """True when MLX reports stream ownership mismatch across threads.""" + message = str(error) + return "no Stream(" in message or "no Stream(gpu" in message + + @dataclass class EngineConfig: """Configuration for the engine.""" @@ -125,6 +133,10 @@ async def stop(self) -> None: except asyncio.CancelledError: pass self._task = None + # Safety net: close batch generator if _engine_loop didn't get a + # chance to clean up (e.g. it was never started). The call is + # idempotent — _close_batch_generator checks for None. + self.scheduler._close_batch_generator() logger.info("Engine stopped") def is_running(self) -> bool: @@ -132,20 +144,83 @@ def is_running(self) -> bool: return self._running async def _engine_loop(self) -> None: - """Main engine loop - hybrid executor for prefill vs generation. + """Main engine loop. - Prefill steps (long prompts) are run in a thread executor to keep - the asyncio event loop responsive. 
Generation-only steps (~1-3ms) - are called directly to avoid ~0.5-2ms context switch overhead, - giving ~5-10% throughput improvement during sustained generation. + scheduler.step runs on one dedicated worker thread. MLX streams are + thread-local, so we rebind generation streams inside that worker. """ - import concurrent.futures - # Single-thread executor ensures MLX calls are never concurrent - _executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mlx-step" - ) loop = asyncio.get_running_loop() + worker = ThreadPoolExecutor(max_workers=1, thread_name_prefix="engine-core") + worker_stream_bound = False + model_thread_stream_bound = False + use_worker_thread = True + stream_thread_fallback_used = False + + def _bind_worker_streams_once() -> None: + nonlocal worker_stream_bound + if not worker_stream_bound: + bind_generation_streams() + worker_stream_bound = True + + def _bind_model_streams_once() -> None: + nonlocal model_thread_stream_bound + if not model_thread_stream_bound: + bind_generation_streams() + model_thread_stream_bound = True + + def _step_on_worker(): + _bind_worker_streams_once() + output = self.scheduler.step() + self._steps_executed += 1 + + if self._steps_executed % _memory_check_interval == 0: + try: + active_mem = mx.get_active_memory() + if active_mem > _memory_pressure_threshold: + mx.clear_cache() + logger.warning( + f"[Memory pressure] {active_mem / 1e9:.1f}GB > " + f"{_memory_pressure_threshold / 1e9:.0f}GB threshold, " + f"forced cache clear" + ) + except Exception: + pass + + return output + + def _step_on_model_thread(): + _bind_model_streams_once() + output = self.scheduler.step() + self._steps_executed += 1 + + if self._steps_executed % _memory_check_interval == 0: + try: + active_mem = mx.get_active_memory() + if active_mem > _memory_pressure_threshold: + mx.clear_cache() + logger.warning( + f"[Memory pressure] {active_mem / 1e9:.1f}GB > " + f"{_memory_pressure_threshold / 1e9:.0f}GB threshold, " + f"forced cache clear" + ) + except Exception: + pass + + return output + + def _recover_stream_thread_error_on_worker() -> None: + _bind_worker_streams_once() + self.scheduler._recover_from_cache_error() + self.scheduler._reschedule_running_requests() + + def _clear_cache_on_worker() -> None: + _bind_worker_streams_once() + mx.clear_cache() + + def _close_batch_generator_on_worker() -> None: + _bind_worker_streams_once() + self.scheduler._close_batch_generator() step_interval = self.config.step_interval stream_interval = self.config.stream_interval @@ -162,94 +237,101 @@ async def _engine_loop(self) -> None: _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 _memory_check_interval = 64 - while self._running: - try: - if self.scheduler.has_requests(): - # Hybrid approach: use executor only when prefill is likely. - # Prefill happens when there are waiting requests that need - # to be inserted into the batch (may block for seconds). - # Generation-only steps are fast (<3ms) and can run inline. 
- has_waiting = self.scheduler.get_num_waiting() > 0 - has_partial = ( - self.scheduler.batch_generator is not None - and getattr(self.scheduler.batch_generator, "_partial", None) - is not None - ) - needs_executor = has_waiting or has_partial - - if needs_executor: - output = await loop.run_in_executor( - _executor, self.scheduler.step - ) - else: - output = self.scheduler.step() - # Yield to event loop after inline step - await asyncio.sleep(0) - self._steps_executed += 1 - - # Emergency memory pressure check - if self._steps_executed % _memory_check_interval == 0: - try: - active_mem = mx.get_active_memory() - if active_mem > _memory_pressure_threshold: - mx.clear_cache() - logger.warning( - f"[Memory pressure] {active_mem / 1e9:.1f}GB > " - f"{_memory_pressure_threshold / 1e9:.0f}GB threshold, " - f"forced cache clear" + try: + while self._running: + try: + if self.scheduler.has_requests(): + if use_worker_thread: + try: + output = await loop.run_in_executor( + worker, _step_on_worker ) - except Exception: - pass - - # Fast path: distribute outputs to collectors - outputs = output.outputs - if outputs: - collectors = self._output_collectors - states = self._stream_states - events = self._finished_events - - for req_output in outputs: - rid = req_output.request_id - collector = collectors.get(rid) - - if collector is not None: - # Optimized: skip stream_interval check when interval=1 - if use_simple_streaming: - collector.put(req_output) - else: - state = states.get(rid) - if state and state.should_send( - req_output.completion_tokens, - req_output.finished, - ): - collector.put(req_output) - state.mark_sent(req_output.completion_tokens) + except Exception as e: + if ( + _is_stream_thread_error(e) + and not stream_thread_fallback_used + ): + await loop.run_in_executor( + worker, _recover_stream_thread_error_on_worker + ) + use_worker_thread = False + stream_thread_fallback_used = True + _bind_model_streams_once() + logger.warning( + "Detected MLX stream/thread mismatch on worker " + "step; switched this engine to model-thread stepping" + ) + continue + raise + else: + output = _step_on_model_thread() + # Yield to event loop after each step. + await asyncio.sleep(0) - if req_output.finished: - event = events.get(rid) - if event: - event.set() + # Fast path: distribute outputs to collectors + outputs = output.outputs + if outputs: + collectors = self._output_collectors + states = self._stream_states + events = self._finished_events - # Free Metal buffers after distributing finished outputs - if output.finished_request_ids: - mx.clear_cache() + for req_output in outputs: + rid = req_output.request_id + collector = collectors.get(rid) - # Always yield to prevent event loop starvation. - # Without this, orphaned requests (client disconnected but - # request still in scheduler) block the entire event loop, - # making the server unresponsive to all HTTP requests. 
- await asyncio.sleep(0) - else: - # No work, yield control - await asyncio.sleep(step_interval) + if collector is not None: + # Optimized: skip stream_interval check when interval=1 + if use_simple_streaming: + collector.put(req_output) + else: + state = states.get(rid) + if state and state.should_send( + req_output.completion_tokens, + req_output.finished, + ): + collector.put(req_output) + state.mark_sent( + req_output.completion_tokens + ) + + if req_output.finished: + event = events.get(rid) + if event: + event.set() + + # Free Metal buffers after distributing finished outputs + if output.finished_request_ids: + if use_worker_thread: + await loop.run_in_executor( + worker, _clear_cache_on_worker + ) + else: + mx.clear_cache() - except asyncio.CancelledError: - break - except Exception as e: - import traceback + # Always yield to prevent event loop starvation. + # Without this, orphaned requests (client disconnected but + # request still in scheduler) block the entire event loop, + # making the server unresponsive to all HTTP requests. + await asyncio.sleep(0) + else: + # No work, yield control + await asyncio.sleep(step_interval) - logger.error(f"Engine loop error: {e}\n{traceback.format_exc()}") - await asyncio.sleep(0.1) + except asyncio.CancelledError: + raise + except Exception as e: + import traceback + + logger.error(f"Engine loop error: {e}\n{traceback.format_exc()}") + await asyncio.sleep(0.1) + finally: + try: + if use_worker_thread: + await loop.run_in_executor(worker, _close_batch_generator_on_worker) + else: + self.scheduler._close_batch_generator() + finally: + worker.shutdown(wait=True) async def add_request( self, @@ -368,16 +450,17 @@ async def stream_outputs( f"{_time.monotonic() - _t0:.1f}s" ) - yield output - if output.finished: finished_normally = True logger.info( f"[stream_outputs] {request_id[:12]} finished normally, " f"{_token_count} tokens in {_time.monotonic() - _t0:.1f}s" ) + yield output break + yield output + except asyncio.TimeoutError: logger.warning( f"[stream_outputs] {request_id[:12]} TIMEOUT after " @@ -547,6 +630,15 @@ def load_cache_from_disk(self, cache_dir: str) -> int: """Load prefix cache from disk.""" return self.scheduler.load_cache_from_disk(cache_dir) + def clear_runtime_caches(self) -> Dict[str, Any] | None: + """Clear scheduler-managed runtime caches.""" + return self.scheduler.clear_runtime_caches() + + def clear_prefix_cache(self) -> None: + """Clear the prefix cache (delegates to scheduler).""" + if hasattr(self.scheduler, "clear_prefix_cache"): + self.scheduler.clear_prefix_cache() + def _release_model(self) -> None: """Release model ownership.""" if self._owns_model and not self._closed: @@ -629,7 +721,7 @@ async def __aexit__(self, *args) -> None: def start(self) -> None: """Start engine (creates task in current loop).""" - asyncio.create_task(self.engine.start()) + self._start_task = asyncio.create_task(self.engine.start()) async def stop(self) -> None: """Stop the engine.""" @@ -691,3 +783,7 @@ def save_cache_to_disk(self, cache_dir: str) -> bool: def load_cache_from_disk(self, cache_dir: str) -> int: """Load prefix cache from disk.""" return self.engine.load_cache_from_disk(cache_dir) + + def clear_runtime_caches(self) -> Dict[str, Any] | None: + """Clear scheduler-managed runtime caches.""" + return self.engine.clear_runtime_caches() diff --git a/vllm_mlx/lifecycle.py b/vllm_mlx/lifecycle.py new file mode 100644 index 000000000..2e33d2802 --- /dev/null +++ b/vllm_mlx/lifecycle.py @@ -0,0 +1,485 @@ +# SPDX-License-Identifier: 
Apache-2.0 +"""Model lifecycle / residency management for vllm-mlx.""" + +from __future__ import annotations + +import asyncio +import inspect +from collections.abc import Awaitable, Callable +from contextlib import suppress +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from .engine.base import BaseEngine, suspend_cancellation + + +class ResidentState(str, Enum): + """Runtime residency state for a configured model.""" + + UNLOADED = "unloaded" + LOADING = "loading" + LOADED = "loaded" + UNLOADING = "unloading" + FAILED = "failed" + + +@dataclass(frozen=True) +class ModelSpec: + """Immutable engine construction inputs for a resident model.""" + + model_key: str + model_name: str + use_batching: bool = False + scheduler_config: Any | None = None + stream_interval: int = 1 + max_tokens: int = 32768 + force_mllm: bool = False + mtp: bool = False + prefill_step_size: int = 2048 + specprefill_enabled: bool = False + specprefill_threshold: int = 8192 + specprefill_keep_pct: float = 0.3 + specprefill_draft_model: str | None = None + + +@dataclass +class ResidentModel: + """Runtime state for a single resident model.""" + + spec: ModelSpec + state: ResidentState = ResidentState.UNLOADED + engine: BaseEngine | None = None + active_requests: int = 0 + last_used_at: float | None = None + loaded_at: float | None = None + last_error: str | None = None + estimated_memory_bytes: int | None = None + _load_waiters: int = field(default=0, repr=False) + _load_waiter_task: asyncio.Task[BaseEngine] | None = field(default=None, repr=False) + _prepare_task: asyncio.Task[None] | None = field(default=None, repr=False) + _abandoned_loading_task: asyncio.Task[BaseEngine] | None = field( + default=None, repr=False + ) + _loading_task: asyncio.Task[BaseEngine] | None = field(default=None, repr=False) + _unloading_task: asyncio.Task[bool] | None = field(default=None, repr=False) + + +class ResidencyManager: + """Single-flight lifecycle manager for resident models.""" + + def __init__( + self, + engine_factory: Callable[[ModelSpec], Awaitable[BaseEngine]], + *, + on_engine_loaded: ( + Callable[[ModelSpec, BaseEngine], Awaitable[None] | None] | None + ) = None, + on_engine_unloading: ( + Callable[[ModelSpec, BaseEngine], Awaitable[None] | None] | None + ) = None, + time_fn: Callable[[], float] | None = None, + auto_unload_idle_seconds: float = 0, + ) -> None: + self._engine_factory = engine_factory + self._on_engine_loaded = on_engine_loaded + self._on_engine_unloading = on_engine_unloading + self._time_fn = time_fn or __import__("time").time + self.auto_unload_idle_seconds = auto_unload_idle_seconds + self._residents: dict[str, ResidentModel] = {} + self._lock = asyncio.Lock() + + def register_model(self, spec: ModelSpec) -> str: + """Register a model spec, or replace a dormant resident entry.""" + existing = self._residents.get(spec.model_key) + if existing is not None: + is_dormant = ( + existing.engine is None + and existing.active_requests == 0 + and existing._load_waiters == 0 + and existing._loading_task is None + and existing._unloading_task is None + and existing.state in {ResidentState.UNLOADED, ResidentState.FAILED} + ) + if not is_dormant: + raise RuntimeError( + f"Cannot replace resident model '{spec.model_key}' while it is live" + ) + + self._residents[spec.model_key] = ResidentModel(spec=spec) + return spec.model_key + + def get_engine(self, model_key: str) -> BaseEngine | None: + """Get the currently loaded engine, if any.""" + return 
self._resident(model_key).engine + + def get_status(self, model_key: str) -> dict[str, Any]: + """Return a serializable snapshot of resident state.""" + resident = self._resident(model_key) + return { + "model_key": resident.spec.model_key, + "model_name": resident.spec.model_name, + "state": resident.state.value, + "active_requests": resident.active_requests, + "last_used_at": resident.last_used_at, + "loaded_at": resident.loaded_at, + "last_error": resident.last_error, + "estimated_memory_bytes": resident.estimated_memory_bytes, + "auto_unload_idle_seconds": self.auto_unload_idle_seconds, + } + + async def ensure_loaded(self, model_key: str) -> BaseEngine: + """Load and start a resident engine if needed.""" + while True: + task: asyncio.Task[BaseEngine] | None = None + unloading_task: asyncio.Task[bool] | None = None + + async with self._lock: + resident = self._resident(model_key) + if ( + resident.state == ResidentState.LOADED + and resident.engine is not None + ): + return resident.engine + + if resident._unloading_task is not None: + unloading_task = resident._unloading_task + else: + if resident._loading_task is None: + resident.state = ResidentState.LOADING + resident.last_error = None + resident._loading_task = asyncio.create_task( + self._load_engine(resident) + ) + resident._load_waiters = 0 + resident._load_waiter_task = resident._loading_task + resident._abandoned_loading_task = None + task = resident._loading_task + resident._load_waiters += 1 + resident._load_waiter_task = task + + if unloading_task is not None: + await asyncio.shield(unloading_task) + continue + + if task is None: + raise RuntimeError(f"No load task available for resident {model_key}") + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + current_task = asyncio.current_task() + cancelling = getattr(current_task, "cancelling", None) + if ( + task.done() + and task.cancelled() + and (cancelling is None or cancelling() == 0) + ): + async with self._lock: + resident = self._resident(model_key) + if resident._abandoned_loading_task is task: + continue + raise + finally: + await self._release_load_waiter(model_key, task) + + async def acquire( + self, + model_key: str, + *, + count_activity: bool = True, + ) -> BaseEngine: + """Acquire a resident engine for request processing.""" + while True: + engine = await self.ensure_loaded(model_key) + async with self._lock: + resident = self._resident(model_key) + if ( + resident.engine is not engine + or resident.state != ResidentState.LOADED + or resident._unloading_task is not None + ): + continue + resident.active_requests += 1 + if count_activity: + resident.last_used_at = self._time_fn() + return engine + + async def release(self, model_key: str, *, count_activity: bool = True) -> None: + """Release a previously acquired resident engine.""" + async with self._lock: + resident = self._resident(model_key) + if resident.active_requests > 0: + resident.active_requests -= 1 + if count_activity: + resident.last_used_at = self._time_fn() + + async def unload_if_idle(self, model_key: str) -> bool: + """Unload a resident engine if it has been idle past the threshold.""" + if self.auto_unload_idle_seconds <= 0: + return False + + while True: + unloading_task: asyncio.Task[bool] | None = None + async with self._lock: + resident = self._resident(model_key) + + if resident._loading_task is not None: + return False + + if resident._unloading_task is not None: + unloading_task = resident._unloading_task + else: + if ( + resident.state != ResidentState.LOADED 
+ or resident.engine is None + or resident.active_requests > 0 + or resident.last_used_at is None + ): + return False + + idle_for = self._time_fn() - resident.last_used_at + if idle_for < self.auto_unload_idle_seconds: + return False + + resident.state = ResidentState.UNLOADING + resident._unloading_task = asyncio.create_task( + self._unload_engine(resident) + ) + unloading_task = resident._unloading_task + + if unloading_task is None: + return False + return await asyncio.shield(unloading_task) + + async def shutdown(self) -> None: + """Stop all loaded residents.""" + keys = list(self._residents.keys()) + failures: list[str] = [] + for model_key in keys: + while True: + loading_task: asyncio.Task[BaseEngine] | None = None + unloading_task: asyncio.Task[bool] | None = None + + async with self._lock: + resident = self._resident(model_key) + + if resident._loading_task is not None: + resident._loading_task.cancel() + loading_task = resident._loading_task + elif ( + resident.engine is None + or resident.state == ResidentState.UNLOADED + ): + break + else: + if resident._unloading_task is None: + resident.state = ResidentState.UNLOADING + resident._unloading_task = asyncio.create_task( + self._unload_engine(resident) + ) + unloading_task = resident._unloading_task + + if loading_task is not None: + with suppress(asyncio.CancelledError): + await loading_task + continue + + if unloading_task is not None: + # Shield the unload so that cancelling shutdown() does not + # orphan a half-stopped engine in UNLOADING state. + try: + unloaded = await asyncio.shield(unloading_task) + except asyncio.CancelledError: + # Shutdown itself was cancelled — finish the in-flight + # unload deterministically before propagating. + with suspend_cancellation(): + unloaded = await unloading_task + raise + if not unloaded: + async with self._lock: + resident = self._resident(model_key) + error = resident.last_error or "resident remained loaded" + failures.append( + f"Failed to unload resident model '{model_key}' during shutdown: {error}" + ) + break + break + + if failures: + if len(failures) == 1: + raise RuntimeError(failures[0]) + raise RuntimeError("; ".join(failures)) + + async def _load_engine(self, resident: ResidentModel) -> BaseEngine: + """Create and start a resident engine.""" + engine: BaseEngine | None = None + try: + engine = await self._engine_factory(resident.spec) + await self._prepare_engine_start(resident, engine) + await engine.start() + await self._run_hook(self._on_engine_loaded, resident.spec, engine) + except asyncio.CancelledError: + await self._cleanup_cancelled_load(resident, engine) + raise + except Exception as exc: + async with self._lock: + abandoned = resident._abandoned_loading_task is asyncio.current_task() + if abandoned: + await self._cleanup_cancelled_load(resident, engine) + raise asyncio.CancelledError() from exc + if engine is not None: + with suppress(Exception): + await engine.stop() + async with self._lock: + resident.state = ResidentState.FAILED + resident.last_error = str(exc) + resident._abandoned_loading_task = None + resident._loading_task = None + raise + + try: + async with self._lock: + resident.engine = engine + resident.state = ResidentState.LOADED + resident.loaded_at = self._time_fn() + resident.last_used_at = resident.loaded_at + resident.last_error = None + resident._abandoned_loading_task = None + resident._loading_task = None + except asyncio.CancelledError: + await self._cleanup_cancelled_load(resident, engine) + raise + + return engine + + async def 
_unload_engine(self, resident: ResidentModel) -> bool: + """Stop and drop a resident engine.""" + engine = resident.engine + if engine is None: + async with self._lock: + resident.state = ResidentState.UNLOADED + resident._unloading_task = None + return False + + try: + await self._run_hook(self._on_engine_unloading, resident.spec, engine) + await engine.stop() + except asyncio.CancelledError: + async with self._lock: + resident.state = ResidentState.LOADED + resident._unloading_task = None + raise + except Exception as exc: + async with self._lock: + resident.engine = engine + resident.state = ResidentState.LOADED + resident.last_error = str(exc) + resident._unloading_task = None + return False + + async with self._lock: + resident.engine = None + resident.state = ResidentState.UNLOADED + resident.loaded_at = None + resident.last_error = None + resident._unloading_task = None + + return True + + def _resident(self, model_key: str) -> ResidentModel: + try: + return self._residents[model_key] + except KeyError as exc: + raise KeyError(f"Resident model '{model_key}' is not registered") from exc + + async def _run_hook( + self, + hook: Callable[[ModelSpec, BaseEngine], Awaitable[None] | None] | None, + spec: ModelSpec, + engine: BaseEngine, + ) -> None: + if hook is None: + return + + result = hook(spec, engine) + if inspect.isawaitable(result): + await result + + async def _prepare_engine_start( + self, + resident: ResidentModel, + engine: BaseEngine, + ) -> None: + """Run blocking startup work away from the serving event loop.""" + prepare_for_start = getattr(engine, "prepare_for_start", None) + if prepare_for_start is None: + return + + prepare_task = asyncio.create_task(asyncio.to_thread(prepare_for_start)) + async with self._lock: + resident._prepare_task = prepare_task + + try: + await asyncio.shield(prepare_task) + except asyncio.CancelledError: + with suspend_cancellation(): + while not prepare_task.done(): + try: + await asyncio.shield(prepare_task) + except asyncio.CancelledError: + continue + except Exception: + break + raise + finally: + async with self._lock: + if resident._prepare_task is prepare_task: + resident._prepare_task = None + + async def _cleanup_cancelled_load( + self, + resident: ResidentModel, + engine: BaseEngine | None, + ) -> None: + """Stop a partially loaded engine and unwind resident state.""" + with suspend_cancellation(): + if engine is not None: + with suppress(Exception): + await engine.stop() + async with self._lock: + resident.engine = None + resident.state = ResidentState.UNLOADED + resident.loaded_at = None + resident.last_error = None + # Keep the abandoned-load marker until a new load task replaces it + # so late waiters on the old task can still recognize a retryable + # cancellation instead of inheriting CancelledError. 
+ resident._loading_task = None + + async def _release_load_waiter( + self, + model_key: str, + task: asyncio.Task[BaseEngine], + ) -> None: + """Drop one waiter from a shared load, canceling abandoned solo loads.""" + task_to_cancel: asyncio.Task[BaseEngine] | None = None + + async with self._lock: + resident = self._resident(model_key) + if resident._load_waiter_task is not task or resident._load_waiters <= 0: + return + + resident._load_waiters -= 1 + if resident._load_waiters == 0: + resident._load_waiter_task = None + if resident._loading_task is task and not task.done(): + resident._abandoned_loading_task = task + task_to_cancel = task + + if task_to_cancel is None: + return + + with suspend_cancellation(): + task_to_cancel.cancel() + with suppress(asyncio.CancelledError): + await task_to_cancel diff --git a/vllm_mlx/mcp/config.py b/vllm_mlx/mcp/config.py index b96b81da4..04c8b629c 100644 --- a/vllm_mlx/mcp/config.py +++ b/vllm_mlx/mcp/config.py @@ -146,10 +146,20 @@ def validate_config(data: Dict[str, Any]) -> MCPConfig: if not isinstance(default_timeout, (int, float)) or default_timeout <= 0: raise ValueError("'default_timeout' must be a positive number") + allowed_high_risk_tools = data.get("allowed_high_risk_tools", []) + if not isinstance(allowed_high_risk_tools, list) or any( + not isinstance(tool, str) or not tool.strip() + for tool in allowed_high_risk_tools + ): + raise ValueError( + "'allowed_high_risk_tools' must be a list of non-empty strings" + ) + return MCPConfig( servers=servers, max_tool_calls=max_tool_calls, default_timeout=default_timeout, + allowed_high_risk_tools=set(allowed_high_risk_tools), ) @@ -184,5 +194,6 @@ def create_example_config() -> str: }, "max_tool_calls": 10, "default_timeout": 30.0, + "allowed_high_risk_tools": [], } return json.dumps(example, indent=2) diff --git a/vllm_mlx/mcp/security.py b/vllm_mlx/mcp/security.py index daef67341..653de6bb4 100644 --- a/vllm_mlx/mcp/security.py +++ b/vllm_mlx/mcp/security.py @@ -8,9 +8,11 @@ import logging import os +import posixpath import re import shutil import time +from urllib.parse import unquote, urlparse from collections import defaultdict from dataclasses import dataclass from pathlib import Path @@ -77,6 +79,29 @@ re.compile(r"<\s*/"), # Read from absolute path ] +# Explicit inline-code execution forms for interpreter-like commands. These +# combinations turn an otherwise whitelisted runtime into a raw code-execution +# primitive and are not needed for normal MCP server launches. +BLOCKED_COMMAND_ARG_RULES: Dict[str, Dict[str, str]] = { + "python": { + "-c": "inline Python execution", + }, + "python3": { + "-c": "inline Python execution", + }, + "node": { + "-e": "inline JavaScript evaluation", + "--eval": "inline JavaScript evaluation", + "-p": "JavaScript evaluation/print", + "--print": "JavaScript evaluation/print", + }, + "npx": { + "-c": "shell command execution", + "--call": "shell command execution", + }, +} +CONTROL_CHARS = ("\n", "\r") + class MCPSecurityError(Exception): """Raised when MCP security validation fails.""" @@ -123,6 +148,51 @@ def __init__( "This should NEVER be used in production!" ) + def _check_control_chars(self, value: str, context: str, server_name: str) -> None: + """Block command separators carried via literal newlines.""" + if any(ch in value for ch in CONTROL_CHARS): + raise MCPSecurityError( + f"MCP server '{server_name}': {context} contains newline characters. " + "Potential command injection blocked." 
+ ) + + def _check_path_traversal(self, value: str, context: str, server_name: str) -> None: + """ + Block parent-directory traversal, including URL-encoded forms. + + This normalizes likely path-like inputs rather than relying only on + the simple ``../`` regex, which can be bypassed by percent-encoding. + """ + candidates = [value] + decoded = unquote(value) + if decoded != value: + candidates.append(decoded) + + for candidate in candidates: + if ( + "/" not in candidate + and "\\" not in candidate + and "%2e" not in value.lower() + ): + continue + + normalized = posixpath.normpath(candidate.replace("\\", "/")) + if normalized == ".." or normalized.startswith("../"): + raise MCPSecurityError( + f"MCP server '{server_name}': {context} contains path traversal: " + f"'{value}'." + ) + + # Also reject any explicit parent segments before or after normalization. + path_parts = [ + part for part in candidate.replace("\\", "/").split("/") if part + ] + if any(part == ".." for part in path_parts): + raise MCPSecurityError( + f"MCP server '{server_name}': {context} contains path traversal: " + f"'{value}'." + ) + def validate_command(self, command: str, server_name: str) -> None: """ Validate that a command is safe to execute. @@ -141,6 +211,9 @@ def validate_command(self, command: str, server_name: str) -> None: ) return + self._check_control_chars(command, "Command", server_name) + self._check_path_traversal(command, "Command", server_name) + # Check for dangerous patterns in command for pattern in DANGEROUS_PATTERNS: if pattern.search(command): @@ -199,6 +272,8 @@ def validate_args(self, args: List[str], server_name: str) -> None: return for i, arg in enumerate(args): + self._check_control_chars(arg, f"Argument {i}", server_name) + self._check_path_traversal(arg, f"Argument {i}", server_name) for pattern in DANGEROUS_ARG_PATTERNS: if pattern.search(arg): raise MCPSecurityError( @@ -210,6 +285,50 @@ def validate_args(self, args: List[str], server_name: str) -> None: f"MCP server '{server_name}': {len(args)} arguments validated successfully" ) + def validate_command_args( + self, + command: str, + args: List[str], + server_name: str, + ) -> None: + """ + Validate command-specific argument combinations. + + Some whitelisted runtimes (python, node, npx) remain acceptable for + launching packaged MCP servers, but inline evaluator flags such as + ``python -c`` and ``node -e`` must be rejected. + """ + if self.allow_unsafe or not args: + return + + base_command = Path(command).name + blocked_rules = BLOCKED_COMMAND_ARG_RULES.get(base_command) + if not blocked_rules: + return + + for i, arg in enumerate(args): + if arg in blocked_rules: + raise MCPSecurityError( + f"MCP server '{server_name}': Argument {i} '{arg}' enables " + f"{blocked_rules[arg]} for '{base_command}', which is not allowed." + ) + + if base_command == "node" and arg.startswith("--eval="): + raise MCPSecurityError( + f"MCP server '{server_name}': Argument {i} '{arg}' enables " + "inline JavaScript evaluation for 'node', which is not allowed." + ) + + if base_command == "npx" and arg.startswith("--call="): + raise MCPSecurityError( + f"MCP server '{server_name}': Argument {i} '{arg}' enables " + "shell command execution for 'npx', which is not allowed." + ) + + logger.debug( + f"MCP server '{server_name}': command-specific arguments validated" + ) + def validate_env(self, env: Optional[Dict[str, str]], server_name: str) -> None: """ Validate environment variables for dangerous values. 
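
# --- Illustrative sketch (not part of this patch) ---------------------------
# A minimal stand-alone version of the inline-evaluator check that the new
# validate_command_args() performs against BLOCKED_COMMAND_ARG_RULES above.
# The table contents mirror the rules in this diff; the helper name
# check_inline_eval() and the example server/package names are invented for
# the example and do not exist in vllm_mlx.
from pathlib import Path

BLOCKED = {
    "python": {"-c": "inline Python execution"},
    "python3": {"-c": "inline Python execution"},
    "node": {"-e": "inline JavaScript evaluation", "--eval": "inline JavaScript evaluation"},
    "npx": {"-c": "shell command execution", "--call": "shell command execution"},
}

def check_inline_eval(command: str, args: list[str]) -> str | None:
    """Return a reason string if args turn a whitelisted runtime into a code evaluator."""
    rules = BLOCKED.get(Path(command).name)
    if not rules:
        return None
    for arg in args:
        if arg in rules:
            return rules[arg]
        if Path(command).name == "node" and arg.startswith("--eval="):
            return "inline JavaScript evaluation"
    return None

# Plain package launch passes; inline evaluation is rejected.
assert check_inline_eval("npx", ["-y", "@org/mcp-server"]) is None
assert check_inline_eval("/usr/bin/python3", ["-c", "print(1)"]) == "inline Python execution"
# -----------------------------------------------------------------------------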
@@ -236,6 +355,14 @@ def validate_env(self, env: Optional[Dict[str, str]], server_name: str) -> None: } for key, value in env.items(): + self._check_control_chars( + value, f"Environment variable '{key}'", server_name + ) + self._check_path_traversal( + value, + f"Environment variable '{key}'", + server_name, + ) # Check for dangerous env var names if key.upper() in dangerous_env_vars: raise MCPSecurityError( @@ -269,6 +396,8 @@ def validate_url(self, url: str, server_name: str) -> None: if self.allow_unsafe: return + self._check_control_chars(url, "URL", server_name) + # Must be http or https if not url.startswith(("http://", "https://")): raise MCPSecurityError( @@ -283,6 +412,11 @@ def validate_url(self, url: str, server_name: str) -> None: f"Consider using HTTPS for production environments." ) + parsed = urlparse(url) + self._check_path_traversal(parsed.path, "URL", server_name) + if parsed.query: + self._check_control_chars(parsed.query, "URL query", server_name) + # Check for dangerous patterns for pattern in DANGEROUS_PATTERNS: if pattern.search(url): @@ -342,6 +476,8 @@ def validate_mcp_server_config( if args: validator.validate_args(args, server_name) + if command: + validator.validate_command_args(command, args, server_name) if env: validator.validate_env(env, server_name) @@ -404,6 +540,7 @@ def __init__( self, allowed_tools: Optional[Set[str]] = None, blocked_tools: Optional[Set[str]] = None, + allowed_high_risk_tools: Optional[Set[str]] = None, blocked_arg_patterns: Optional[List[re.Pattern]] = None, max_calls_per_minute: int = 60, audit_callback: Optional[Callable[[ToolExecutionAudit], None]] = None, @@ -415,6 +552,7 @@ def __init__( Args: allowed_tools: If set, only these tools can be executed (whitelist mode). blocked_tools: Tools that are always blocked (blacklist mode). + allowed_high_risk_tools: High-risk tools that are explicitly allowed. blocked_arg_patterns: Patterns to block in tool arguments. max_calls_per_minute: Rate limit for tool calls (0 = unlimited). audit_callback: Optional callback for audit events. @@ -422,6 +560,9 @@ def __init__( """ self.allowed_tools = allowed_tools self.blocked_tools = blocked_tools or set() + self.allowed_high_risk_tools = { + tool.lower() for tool in (allowed_high_risk_tools or set()) + } self.blocked_arg_patterns = ( blocked_arg_patterns or DANGEROUS_TOOL_ARG_PATTERNS.copy() ) @@ -482,7 +623,7 @@ def validate_tool_execution( ) # Check for high-risk tool patterns - self._check_high_risk_tool(tool_name) + self._check_high_risk_tool(tool_name, full_name) # Validate arguments self._validate_arguments(tool_name, arguments) @@ -500,16 +641,26 @@ def _is_blocked(self, tool_name: str, full_name: str) -> bool: or tool_name.lower() in self.blocked_tools ) - def _check_high_risk_tool(self, tool_name: str) -> None: + def _check_high_risk_tool(self, tool_name: str, full_name: str) -> None: """Check if tool matches high-risk patterns.""" tool_lower = tool_name.lower() + full_lower = full_name.lower() for pattern in HIGH_RISK_TOOL_PATTERNS: if pattern in tool_lower: - logger.warning( - f"High-risk tool detected: '{tool_name}' matches pattern '{pattern}'. " - f"Ensure this tool is from a trusted MCP server." + if ( + tool_lower in self.allowed_high_risk_tools + or full_lower in self.allowed_high_risk_tools + ): + logger.warning( + "Allowing high-risk tool '%s' due to explicit allowlist entry", + full_name, + ) + return + raise MCPSecurityError( + f"High-risk tool '{tool_name}' is blocked by security policy. 
" + f"Add '{full_name}' or '{tool_name}' to allowed_high_risk_tools " + f"to allow it explicitly." ) - break def _validate_arguments(self, tool_name: str, arguments: Dict[str, Any]) -> None: """Validate tool arguments for dangerous patterns.""" diff --git a/vllm_mlx/mcp/types.py b/vllm_mlx/mcp/types.py index 00830e569..e990e78bd 100644 --- a/vllm_mlx/mcp/types.py +++ b/vllm_mlx/mcp/types.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set class MCPTransport(str, Enum): @@ -85,6 +85,7 @@ class MCPConfig: servers: Dict[str, MCPServerConfig] = field(default_factory=dict) max_tool_calls: int = 10 default_timeout: float = 30.0 + allowed_high_risk_tools: Set[str] = field(default_factory=set) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "MCPConfig": @@ -98,6 +99,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "MCPConfig": servers=servers, max_tool_calls=data.get("max_tool_calls", 10), default_timeout=data.get("default_timeout", 30.0), + allowed_high_risk_tools=set(data.get("allowed_high_risk_tools", [])), ) diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index f68192f4a..397e845b6 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -38,6 +38,9 @@ _DEFAULT_MEMORY_PERCENT = 0.20 # 20% of available RAM _MIN_MEMORY_BYTES = 100 * _BYTES_PER_MB # Minimum 100MB _MAX_ENTRIES_FALLBACK = 50 # Fallback if memory detection fails +# Bump this when the cache on-disk format or KV semantics change. +# Loading a cache with a different version is rejected automatically. +_CACHE_PERSIST_VERSION = 3 def _get_available_memory() -> int: @@ -59,6 +62,44 @@ def _get_available_memory() -> int: return 0 +_TURBOQUANT_UPSTREAM_PR = "https://github.com/ml-explore/mlx-lm/pull/1067" + + +def _check_turboquant_capability() -> str | None: + """Verify the installed mlx-lm exposes TurboQuant KV cache support. + + Returns None if capability is available, otherwise a human-readable + description of what is missing. Callers can use the return value to + fail fast with an actionable error message. + + TurboQuant support requires two pieces, both provided by the upstream + mlx-lm PR 1067: + 1. the ``mlx_lm.models.turboquant_cache`` module (the cache class), and + 2. the ``to_turbo_quantized`` method bolted onto ``KVCache`` (the entry + point used to compress stored prefix-cache entries). + """ + try: + from mlx_lm.models.turboquant_cache import TurboQuantKVCache # noqa: F401 + except ImportError: + return ( + "mlx_lm.models.turboquant_cache is not available. " + f"Install an mlx-lm build with TurboQuant support (see {_TURBOQUANT_UPSTREAM_PR})." + ) + + try: + from mlx_lm.models.cache import KVCache + except ImportError as e: + return f"mlx_lm.models.cache.KVCache is not importable: {e}" + + if not hasattr(KVCache, "to_turbo_quantized"): + return ( + "mlx_lm.models.cache.KVCache lacks to_turbo_quantized(). " + f"Upgrade to an mlx-lm build with TurboQuant support (see {_TURBOQUANT_UPSTREAM_PR})." + ) + + return None + + def _array_memory(arr) -> int: """ Estimate array memory from shape+dtype without triggering lazy eval. 
@@ -103,7 +144,37 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: total_bytes = 0 + try: + from mlx_lm.models.turboquant_cache import TurboQuantKVCache as _TQ + except ImportError: + _TQ = None + + # Two TurboQuant paths exist because stored prefix entries use + # _TurboQuantCacheWrapper (carries orig_type/orig_attrs metadata), while + # ad-hoc callers can pass a bare TurboQuantKVCache (e.g. tests, direct API). + # The wrapper path iterates .state (a flat tuple of arrays exposed by the + # upstream class); the bare path reads the packed arrays directly so we do + # not rely on the public .state tuple layout for unwrapped instances. for layer_cache in cache: + if ( + _TQ is not None + and hasattr(layer_cache, "layer") + and isinstance(layer_cache.layer, _TQ) + ): + for arr in layer_cache.layer.state: + total_bytes += _array_memory(arr) + continue + if _TQ is not None and isinstance(layer_cache, _TQ): + if layer_cache.k_packed is not None: + for arr in ( + layer_cache.k_packed, + layer_cache.v_packed, + layer_cache.k_norms, + layer_cache.v_norms, + ): + if arr is not None: + total_bytes += _array_memory(arr) + continue # Handle different cache object types # Check dict first since dicts have .keys() method that would match below if isinstance(layer_cache, dict) and "state" in layer_cache: @@ -165,6 +236,7 @@ class MemoryCacheConfig: kv_bits: int = 8 kv_group_size: int = 64 kv_min_quantize_tokens: int = 256 + turbo_kv_bits: int | None = None # TurboQuant: 1-4 bit, 4.6x compression at 3-bit def __post_init__(self) -> None: if not 0.0 < self.max_memory_percent <= 1.0: @@ -177,6 +249,29 @@ def __post_init__(self) -> None: raise ValueError( f"kv_min_quantize_tokens must be >= 0, got {self.kv_min_quantize_tokens}" ) + if self.turbo_kv_bits is not None: + if self.turbo_kv_bits not in (1, 2, 3, 4): + raise ValueError(f"turbo_kv_bits must be 1-4, got {self.turbo_kv_bits}") + if self.kv_quantize: + # TurboQuant and standard group-wise quantization are two + # different storage formats for prefix-cache entries; only one + # path runs in _apply_quantization. Rejecting the combination + # avoids masking user intent (which one did they actually want?). + raise ValueError( + "turbo_kv_bits and kv_quantize are mutually exclusive; " + "pick one compression path" + ) + missing = _check_turboquant_capability() + if missing is not None: + raise RuntimeError( + f"turbo_kv_bits={self.turbo_kv_bits} requested but TurboQuant " + f"support is unavailable: {missing}" + ) + + @property + def needs_dequantize(self) -> bool: + """Whether stored caches need dequantization on fetch.""" + return self.kv_quantize or self.turbo_kv_bits is not None def compute_memory_limit(self) -> int: """ @@ -254,6 +349,15 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: ) +def _capture_cache_attrs(layer: Any) -> dict[str, Any]: + """Capture cache-type metadata needed to reconstruct a layer.""" + attrs = {} + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + attrs[attr] = getattr(layer, attr) + return attrs + + def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: """Create copies of cache layers with the last ``trim_by`` positions removed. @@ -267,11 +371,17 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: For RotatingKVCache: actually trims the circular buffer — reducing offset alone breaks ``size()`` / ``_temporal_order`` invariants. - Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper. 
+ Supports KVCache, RotatingKVCache, _QuantizedCacheWrapper, and + _TurboQuantCacheWrapper. """ import mlx.core as mx from mlx_lm.models.cache import RotatingKVCache + try: + from mlx_lm.models.turboquant_cache import TurboQuantKVCache + except ImportError: + TurboQuantKVCache = None # noqa: N806 + trimmed: list[Any] = [] eval_targets: list[Any] = [] for layer_cache in cache: @@ -286,6 +396,24 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: tc.orig_type = layer_cache.orig_type tc.orig_attrs = layer_cache.orig_attrs trimmed.append(tc) + elif isinstance(layer_cache, _TurboQuantCacheWrapper): + if not layer_cache.is_trimmable(): + trimmed.append(layer_cache) + continue + # We rely on upstream TurboQuantKVCache.copy() returning a copy + # whose offset can be mutated independently of the stored layer. + # The packed k/v arrays are shared (MLX arrays are immutable, so + # sharing is safe), but mutable scalar state like offset and the + # dequant buffers must be per-copy. If upstream ever changes copy() + # semantics to share these too, this will corrupt stored entries. + tc = _TurboQuantCacheWrapper.__new__(_TurboQuantCacheWrapper) + tc.layer = layer_cache.layer.copy() + tc.layer.offset = max(layer_cache.layer.offset - trim_by, 0) + tc.offset = tc.layer.offset + tc.bits = layer_cache.bits + tc.orig_type = layer_cache.orig_type + tc.orig_attrs = layer_cache.orig_attrs + trimmed.append(tc) elif isinstance(layer_cache, RotatingKVCache): if layer_cache.keys is None or trim_by <= 0: trimmed.append(layer_cache) @@ -379,6 +507,15 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: tc.values = layer_cache.values tc._idx = layer_cache._idx trimmed.append(tc) + elif TurboQuantKVCache is not None and isinstance( + layer_cache, TurboQuantKVCache + ): + if not hasattr(layer_cache, "copy"): + trimmed.append(layer_cache) + continue + tc = layer_cache.copy() + tc.offset = max(layer_cache.offset - trim_by, 0) + trimmed.append(tc) elif ( hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys") @@ -386,9 +523,27 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: ): orig_cls = type(layer_cache) tc = orig_cls.__new__(orig_cls) - tc.keys = layer_cache.keys - tc.values = layer_cache.values - tc.offset = max(layer_cache.offset - trim_by, 0) + new_offset = max(layer_cache.offset - trim_by, 0) + keys = layer_cache.keys + values = layer_cache.values + # Slice the arrays down to new_offset rather than just shrinking the + # offset pointer. Sharing the original (over-sized) array across + # requests lets attention paths that read the full underlying + # buffer (e.g. Gemma 4's KV-shared layers, which read cache.state + # directly instead of going through update_and_fetch) see stale + # tokens from the previous owner — issue #384. 
+ if ( + keys is not None + and hasattr(keys, "shape") + and len(keys.shape) >= 3 + and new_offset < keys.shape[-2] + ): + tc.keys = keys[..., :new_offset, :] + tc.values = values[..., :new_offset, :] + else: + tc.keys = keys + tc.values = values + tc.offset = new_offset # Preserve type-specific attrs (max_size, keep, step, _idx) for attr in ("max_size", "keep", "step", "_idx"): if hasattr(layer_cache, attr): @@ -487,10 +642,41 @@ def __init__(self, layer: Any, bits: int, group_size: int): self.group_size = group_size self.orig_type = type(layer) # Preserve RotatingKVCache-specific attrs - self.orig_attrs = {} - for attr in ("max_size", "keep", "step", "_idx"): - if hasattr(layer, attr): - self.orig_attrs[attr] = getattr(layer, attr) + self.orig_attrs = _capture_cache_attrs(layer) + + +_TURBOQUANT_TRIM_WARNED = False + + +class _TurboQuantCacheWrapper: + """Lightweight wrapper storing TurboQuant cache + original cache metadata.""" + + __slots__ = ("layer", "offset", "bits", "orig_type", "orig_attrs") + + def __init__(self, layer: Any, bits: int): + self.layer = layer.to_turbo_quantized(bits=bits) + self.offset = self.layer.offset + self.bits = bits + self.orig_type = type(layer) + self.orig_attrs = _capture_cache_attrs(layer) + + def is_trimmable(self) -> bool: + # Trimming requires a public copy() on the underlying TurboQuant cache + # so we can adjust offset without mutating the stored entry. Older + # mlx-lm builds may ship TurboQuantKVCache without copy(); in that case + # the scheduler treats this layer as non-trimmable and falls back to + # full-prefix matching (correct, just less efficient). + trimmable = hasattr(self.layer, "copy") + if not trimmable: + global _TURBOQUANT_TRIM_WARNED + if not _TURBOQUANT_TRIM_WARNED: + logger.warning( + "TurboQuant cache layer has no copy() method; prefix cache " + "trimming (supersequence / LCP reuse) is disabled. " + f"Upgrade mlx-lm for full trim support (see {_TURBOQUANT_UPSTREAM_PR})." + ) + _TURBOQUANT_TRIM_WARNED = True + return trimmable def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]: @@ -512,6 +698,27 @@ def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> li return quantized +def _turbo_quantize_cache(cache: list[Any], bits: int = 3) -> list[Any]: + """Compress KVCache layers with TurboQuant (4.6x at 3-bit). + + Uses PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook + quantization with fused Metal kernels. See arXiv 2504.19874. + """ + from mlx_lm.models.cache import KVCache + + compressed = [] + for layer in cache: + if ( + type(layer) is KVCache + and layer.keys is not None + and hasattr(layer, "to_turbo_quantized") + ): + compressed.append(_TurboQuantCacheWrapper(layer, bits)) + else: + compressed.append(layer) + return compressed + + def _dequantize_cache(cache: list[Any]) -> list[Any]: """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers. @@ -520,6 +727,11 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: """ import mlx.core as mx + try: + from mlx_lm.models.turboquant_cache import TurboQuantKVCache + except ImportError: + TurboQuantKVCache = None # noqa: N806 + result = [] for layer in cache: if isinstance(layer, _QuantizedCacheWrapper): @@ -533,10 +745,40 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: *layer.values, group_size=layer.group_size, bits=layer.bits ) kv.offset = layer.offset + # Slice the dequantized arrays down to offset so that readers + # which bypass offset (e.g. 
Gemma 4 KV-shared layers reading + # cache.state directly) cannot see stale tokens from a previous + # request. Mirrors the plain-KVCache slice in + # _trim_cache_offset — see issue #384. + if ( + kv.keys is not None + and hasattr(kv.keys, "shape") + and len(kv.keys.shape) >= 3 + and kv.offset < kv.keys.shape[-2] + ): + kv.keys = kv.keys[..., : kv.offset, :] + kv.values = kv.values[..., : kv.offset, :] # Restore type-specific attrs (max_size, keep, step, _idx) for attr, val in layer.orig_attrs.items(): setattr(kv, attr, val) result.append(kv) + elif isinstance(layer, _TurboQuantCacheWrapper): + if not hasattr(layer.layer, "dequantize"): + result.append(layer) + continue + dequantized = layer.layer.dequantize() + orig_cls = layer.orig_type + kv = orig_cls.__new__(orig_cls) + kv.keys = ( + mx.array(dequantized.keys) if dequantized.keys is not None else None + ) + kv.values = ( + mx.array(dequantized.values) if dequantized.values is not None else None + ) + kv.offset = layer.offset + for attr, val in layer.orig_attrs.items(): + setattr(kv, attr, val) + result.append(kv) elif hasattr(layer, "keys") and hasattr(layer, "offset"): # Deep-copy non-quantized cache layers (e.g. RotatingKVCache) # so model's in-place mutations don't corrupt stored entries @@ -549,11 +791,55 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: if hasattr(layer, attr): setattr(kv, attr, getattr(layer, attr)) result.append(kv) + elif TurboQuantKVCache is not None and isinstance(layer, TurboQuantKVCache): + if hasattr(layer, "empty") and layer.empty(): + result.append(layer) + elif hasattr(layer, "dequantize"): + result.append(layer.dequantize()) + else: + result.append(layer) else: result.append(layer) return result +def _compute_model_fingerprint(model: Any) -> str: + """Compute a fingerprint from model architecture for cache compatibility. + + Used to reject disk-persisted caches created by a different model or + a different quantisation of the same model. The fingerprint is a + short hex digest of (num_layers, hidden_size, vocab_size, num_kv_heads, + head_dim) — lightweight and deterministic. + """ + import hashlib + + parts: list[str] = [] + # Walk model.config / model.args / direct attributes + for cfg_attr in ("config", "args", "model_config"): + cfg = getattr(model, cfg_attr, None) + if cfg is not None: + break + if cfg is None: + cfg = model # fallback: attributes on the model itself + + for key in ( + "num_hidden_layers", + "hidden_size", + "vocab_size", + "num_key_value_heads", + "head_dim", + "intermediate_size", + "model_type", + ): + val = getattr(cfg, key, None) + if val is not None: + parts.append(f"{key}={val}") + + fingerprint = hashlib.sha256("|".join(parts).encode()).hexdigest()[:16] + logger.debug(f"[model_fingerprint] {fingerprint} ({', '.join(parts)})") + return fingerprint + + class MemoryAwarePrefixCache: """ Prefix cache with memory-based eviction. 
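
# --- Illustrative sketch (not part of this patch) ---------------------------
# What _compute_model_fingerprint() keys on, assuming it runs inside
# vllm_mlx/memory_cache.py where the function and module logger are defined.
# DummyConfig / DummyModel and their values are invented stand-ins; any object
# exposing .config (or .args / .model_config) with these attributes yields the
# same deterministic 16-hex-char digest.
from dataclasses import dataclass

@dataclass
class DummyConfig:
    num_hidden_layers: int = 28
    hidden_size: int = 3072
    vocab_size: int = 128256
    num_key_value_heads: int = 8
    head_dim: int = 128
    model_type: str = "llama"

@dataclass
class DummyModel:
    config: DummyConfig

fp_a = _compute_model_fingerprint(DummyModel(DummyConfig()))
fp_b = _compute_model_fingerprint(DummyModel(DummyConfig(hidden_size=4096)))
assert len(fp_a) == 16   # short hex digest, stable across restarts
assert fp_a != fp_b      # different architecture -> incompatible disk cache
# -----------------------------------------------------------------------------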
@@ -586,6 +872,7 @@ def __init__( """ self._model_id = id(model) self._config = config or MemoryCacheConfig() + self._model_fingerprint = _compute_model_fingerprint(model) # OrderedDict maintains insertion order for LRU # Key: tuple(tokens), Value: _CacheEntry @@ -606,6 +893,9 @@ def __init__( # Track the match type from the last fetch() call self._last_match_type: str | None = None + # Optional SSD cold tier (set via set_ssd_tier()) + self._ssd_tier = None + logger.info( f"MemoryAwarePrefixCache initialized: " f"max_memory={self._max_memory / _BYTES_PER_MB:.1f}MB, " @@ -647,7 +937,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._last_match_type = "exact" cache_out = ( _dequantize_cache(entry.cache) - if self._config.kv_quantize + if self._config.needs_dequantize else entry.cache ) return cache_out, [] @@ -705,7 +995,10 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: excess = n_cached - n_requested has_non_trimmable = any( - not (hasattr(lc, "offset") and hasattr(lc, "keys")) + not ( + (hasattr(lc, "is_trimmable") and lc.is_trimmable()) + or (hasattr(lc, "offset") and hasattr(lc, "keys")) + ) for lc in best_super.cache ) @@ -722,7 +1015,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._last_match_type = "supersequence" trimmed_cache = ( _dequantize_cache(trimmed_cache) - if self._config.kv_quantize + if self._config.needs_dequantize else trimmed_cache ) return trimmed_cache, [] @@ -733,7 +1026,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._last_match_type = "supersequence" cache_out = ( _dequantize_cache(best_super.cache) - if self._config.kv_quantize + if self._config.needs_dequantize else best_super.cache ) return cache_out, [] @@ -747,7 +1040,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._last_match_type = "prefix" cache_out = ( _dequantize_cache(best_match.cache) - if self._config.kv_quantize + if self._config.needs_dequantize else best_match.cache ) return cache_out, remaining @@ -790,7 +1083,10 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: excess = len(best_lcp_entry.tokens) - best_lcp_length has_non_trimmable = any( - not (hasattr(lc, "offset") and hasattr(lc, "keys")) + not ( + (hasattr(lc, "is_trimmable") and lc.is_trimmable()) + or (hasattr(lc, "offset") and hasattr(lc, "keys")) + ) for lc in best_lcp_entry.cache ) logger.debug( @@ -822,7 +1118,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._last_match_type = "lcp" trimmed_cache = ( _dequantize_cache(trimmed_cache) - if self._config.kv_quantize + if self._config.needs_dequantize else trimmed_cache ) return trimmed_cache, remaining @@ -867,8 +1163,13 @@ def store( # Trim oversized KV arrays to actual used size cache = _trim_to_offset(cache) - # Quantize if enabled and sequence is long enough + # Compress KV cache for storage: TurboQuant (4.6x) or standard quantization (2x) if ( + self._config.turbo_kv_bits is not None + and len(tokens) >= self._config.kv_min_quantize_tokens + ): + cache = _turbo_quantize_cache(cache, self._config.turbo_kv_bits) + elif ( self._config.kv_quantize and len(tokens) >= self._config.kv_min_quantize_tokens ): @@ -946,7 +1247,11 @@ def _remove_from_sorted(self, key: tuple[int, ...]) -> None: self._sorted_keys.pop(idx) def _evict_lru(self) -> None: - """Evict the least recently used entry.""" + """Evict the least recently used entry. 
+ + If an SSD tier is attached, the entry is spilled to disk instead + of being discarded. + """ if not self._entries: return @@ -958,9 +1263,14 @@ def _evict_lru(self) -> None: self._stats.entry_count = len(self._entries) self._stats.current_memory_bytes = self._current_memory + # Spill to SSD tier if available + if self._ssd_tier is not None: + self._ssd_tier.enqueue_spill(tokens_key, entry.cache, entry.memory_bytes) + logger.debug( f"[lru_evict] removed {len(tokens_key)} tokens, " f"freed {entry.memory_bytes / _BYTES_PER_MB:.2f}MB" + f"{' (spilled to SSD)' if self._ssd_tier is not None else ''}" ) def remove(self, tokens: list[int]) -> bool: @@ -1021,6 +1331,53 @@ def __contains__(self, tokens: list[int]) -> bool: """Check if tokens are cached.""" return tuple(tokens) in self._entries + def set_ssd_tier(self, ssd_tier) -> None: + """Attach an SSD cache tier for eviction spilling. + + When set, evicted entries are spilled to SSD instead of discarded. + + Args: + ssd_tier: An SSDCacheTier instance (or None to disable). + """ + self._ssd_tier = ssd_tier + if ssd_tier is not None: + logger.info("[memory_cache] SSD tier attached for eviction spilling") + + def check_ssd(self, tokens: list[int]) -> dict | None: + """Check if tokens have an SSD cache hit (without reading data). + + Returns metadata dict with 'match_type' ('exact' or 'prefix') if + found in SSD tier, None if not found. For prefix matches, the dict + also includes 'matched_tokens' (the count of tokens the SSD entry + covers). + + This is a fast synchronous call (SQLite lookup only). + The actual data read happens via the scheduler handoff. + """ + if self._ssd_tier is None: + return None + + tokens_key = tuple(tokens) + + # If already in RAM, no SSD needed + if tokens_key in self._entries: + return None + + # Check SSD tier — exact match first, then prefix + candidate = self._ssd_tier.lookup_ssd(tokens_key) + if candidate is not None: + candidate["match_type"] = "exact" + candidate["matched_tokens"] = len(tokens) + return candidate + + prefix = self._ssd_tier.lookup_ssd_prefix(tokens_key) + if prefix is not None: + prefix["match_type"] = "prefix" + prefix["matched_tokens"] = prefix["num_tokens"] + return prefix + + return None + # ----------------------------------------------------------------- # Disk persistence — survives server restarts # ----------------------------------------------------------------- @@ -1056,7 +1413,8 @@ def save_to_disk(self, cache_dir: str) -> bool: return False index = { - "version": 2, + "version": _CACHE_PERSIST_VERSION, + "model_fingerprint": self._model_fingerprint, "num_entries": len(self._entries), "total_memory_bytes": self._current_memory, "entries": [], @@ -1134,8 +1492,20 @@ def load_from_disk(self, cache_dir: str) -> int: index = json.load(f) version = index.get("version", 1) - if version < 2: - logger.warning(f"[cache_persist] unsupported version {version}, skipping") + if version != _CACHE_PERSIST_VERSION: + logger.warning( + f"[cache_persist] version mismatch: disk={version} " + f"current={_CACHE_PERSIST_VERSION}, discarding stale cache" + ) + return 0 + + disk_fp = index.get("model_fingerprint", "") + if disk_fp and disk_fp != self._model_fingerprint: + logger.warning( + f"[cache_persist] model fingerprint mismatch: " + f"disk={disk_fp} current={self._model_fingerprint}, " + f"discarding incompatible cache" + ) return 0 loaded = 0 diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index 6b1fbc9e7..09e43d0ac 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ 
b/vllm_mlx/mllm_batch_generator.py @@ -17,6 +17,7 @@ """ import logging +import os import time from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Tuple @@ -31,6 +32,32 @@ logger = logging.getLogger(__name__) +def _processors_can_retire(processors: Optional[List[Callable]]) -> bool: + """True when any processor advertises a retire-to-content transition.""" + if os.getenv("VLLM_MLX_ENABLE_THINKING_RETIREMENT_RESUME") != "1": + return False + return bool(processors) and any( + isinstance(getattr(p, "is_retired", None), bool) for p in processors + ) + + +def _drop_retired_processors( + processors: Optional[List[Callable]], +) -> tuple[Optional[List[Callable]], int]: + """Drop retire-capable processors that have completed their work.""" + if not processors: + return processors, 0 + + remaining = [] + retired_count = 0 + for processor in processors: + if getattr(processor, "is_retired", False) is True: + retired_count += 1 + continue + remaining.append(processor) + return (remaining or None), retired_count + + class PrefillAbortedError(Exception): """Raised when a prefill is aborted due to client disconnect.""" @@ -39,6 +66,39 @@ def __init__(self, request_id: str): super().__init__(f"Prefill aborted for request {request_id}") +def _cache_eval_tensors(cache: List[Any]) -> List[Any]: + """Return realized tensors that break lazy cache graphs between chunks.""" + tensors: List[Any] = [] + for c in cache: + keys = getattr(c, "keys", None) + values = getattr(c, "values", None) + if keys is not None or values is not None: + if keys is not None: + tensors.append(keys) + if values is not None: + tensors.append(values) + continue + + try: + state = getattr(c, "state", None) + except AttributeError: + state = None + if state is None: + continue + if isinstance(state, (list, tuple)): + tensors.extend(s for s in state if s is not None) + else: + tensors.append(state) + return tensors + + +def _eval_prompt_cache(cache: List[Any]) -> None: + """Evaluate all cache tensors used by hybrid chunked prefill.""" + tensors = _cache_eval_tensors(cache) + if tensors: + mx.eval(*tensors) + + @dataclass class MLLMBatchRequest: """ @@ -60,6 +120,10 @@ class MLLMBatchRequest: min_p: float = 0.0 presence_penalty: float = 0.0 repetition_penalty: float = 1.0 + # Extra logits processors (e.g. JSON schema constrained decoding). + # Merged with built-in repetition/presence penalty processors in + # ``_prefill_batch``. + logits_processors: Optional[List[Callable]] = None # Processed inputs (set after vision preprocessing) input_ids: Optional[mx.array] = None @@ -177,13 +241,16 @@ def extend(self, other: "MLLMBatch") -> None: self.samplers = list(self_s) + list(other_s) # Extend cache - handle both BatchKVCache (.keys/.values) and - # ArraysCache (.cache list) from hybrid models like Qwen3.5 + # ArraysCache (.cache list) from hybrid models like Qwen3.5. Some + # cache integrations, such as quantized SDPA caches, expose state only + # through empty()/extend() and do not publish .keys. for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): try: has_kv = hasattr(c, "keys") and c.keys is not None has_arrays = hasattr(c, "cache") - if has_kv or has_arrays: + has_extendable_state = hasattr(c, "empty") and not c.empty() + if has_kv or has_arrays or has_extendable_state: c.extend(o) except Exception as e: logger.warning(f"Failed to extend cache: {e}") @@ -705,10 +772,17 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None: 3. 
Running vision encoder to get features Uses vision cache to skip processing for repeated images. + Idempotent: if input_ids is already set, returns immediately. Args: request: Request to preprocess """ + # Already preprocessed (e.g. by early executor offloading in + # _process_loop). Only skip for text-only requests; vision + # requests need pixel cache lookup even if input_ids was set. + if request.input_ids is not None and not request.images and not request.videos: + return + from mlx_vlm.utils import prepare_inputs tic = time.perf_counter() @@ -930,6 +1004,11 @@ def _run_chunked_text_prefill( return output.logits return output + logger.info( + f"[chunked_prefill] Starting {request.request_id[:12]}: " + f"{total} tokens, step={step}" + ) + # Process all chunks except the last processed = 0 chunk_count = 0 @@ -945,11 +1024,24 @@ def _run_chunked_text_prefill( chunk = input_ids[:, processed : processed + step] self.language_model(chunk, cache=cache) - mx.eval([c.state for c in cache]) + # Eval ALL cache types to break the lazy graph between chunks. + # ArraysCache (e.g. GatedDeltaNet) has .state; KVCache (full + # attention) has .keys/.values. Hybrid models like Qwen3.5 use + # both. Skipping either type lets the computation graph grow + # across chunks → OOM on long prompts. + _eval_prompt_cache(cache) processed += step chunk_count += 1 self._prefill_progress[request.request_id] = (processed, total) + # Log progress every 10 chunks so operators can see prefill + # is progressing (not hanging) during long prompts. + if chunk_count % 10 == 0: + logger.info( + f"[chunked_prefill] {request.request_id[:12]}: " + f"chunk {chunk_count}, {processed}/{total} tokens" + ) + # Release Metal buffer pool periodically. Full-attention layers # produce attention score buffers that grow each chunk (1024 × # growing_context). 
Old smaller buffers can't be reused, so the @@ -963,6 +1055,12 @@ def _run_chunked_text_prefill( request.vision_encoded = True self._prefill_progress[request.request_id] = (total, total) + if chunk_count > 0: + logger.info( + f"[chunked_prefill] Completed {request.request_id[:12]}: " + f"{total} tokens in {chunk_count + 1} chunks" + ) + if hasattr(output, "logits"): return output.logits return output @@ -1026,6 +1124,7 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: MLLMBatch ready for generation """ from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_logits_processors, make_sampler tic = time.perf_counter() @@ -1059,6 +1158,60 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # All requests failed return None + logits_processors_by_request: dict[str, Optional[List[Callable]]] = {} + samplers_by_request: dict[str, Optional[Callable]] = {} + for req in requests: + need_rep = req.repetition_penalty and req.repetition_penalty != 1.0 + need_pres = req.presence_penalty and req.presence_penalty != 0.0 + combined: List[Callable] = [] + if need_rep or need_pres: + lp_kwargs = {} + if need_rep: + lp_kwargs["repetition_penalty"] = req.repetition_penalty + if need_pres: + lp_kwargs["presence_penalty"] = req.presence_penalty + combined.extend(make_logits_processors(**lp_kwargs)) + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"rep_penalty={req.repetition_penalty} " + f"pres_penalty={req.presence_penalty}" + ) + if req.logits_processors: + combined.extend(req.logits_processors) + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"extra_logits_processors={len(req.logits_processors)}" + ) + logits_processors_by_request[req.request_id] = combined or None + + samplers_by_request[req.request_id] = make_sampler( + temp=req.temperature, + top_p=req.top_p, + top_k=req.top_k, + min_p=req.min_p, + ) + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"temp={req.temperature} top_p={req.top_p} " + f"top_k={req.top_k} min_p={req.min_p}" + ) + + def _sample_first_token(req: MLLMBatchRequest, logits: mx.array): + sample_logits = logits + processors = logits_processors_by_request.get(req.request_id) + if processors: + empty_tokens = mx.array([], dtype=mx.uint32) + for processor in processors: + sample_logits = processor(empty_tokens, sample_logits) + + logprobs = sample_logits - mx.logsumexp( + sample_logits, axis=-1, keepdims=True + ) + sampler = samplers_by_request.get(req.request_id) or self.sampler + sampled = sampler(logprobs) + mx.eval(sampled, logprobs) + return sampled, logprobs + total_prompt_tokens = sum( req.input_ids.size if req.input_ids is not None else 1 for req in requests ) @@ -1177,7 +1330,8 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: chunk = remaining[:, processed : processed + step] self.language_model(chunk, cache=request_cache) - mx.eval([c.state for c in request_cache]) + # Eval ALL cache types (see _run_chunked_text_prefill) + _eval_prompt_cache(request_cache) processed += step chunk_count += 1 self._prefill_progress[req.request_id] = ( @@ -1198,11 +1352,8 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: logits = logits.logits last_logits = logits[:, -1, :] - logprobs = last_logits - mx.logsumexp( - last_logits, axis=-1, keepdims=True - ) - sampled = self.sampler(logprobs) - mx.eval(sampled, logprobs) + + sampled, logprobs = _sample_first_token(req, last_logits) first_tokens.append(sampled.item()) 
all_logprobs.append(logprobs.squeeze(0)) @@ -1236,11 +1387,8 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: logits = logits.logits last_logits = logits[:, -1, :] - logprobs = last_logits - mx.logsumexp( - last_logits, axis=-1, keepdims=True - ) - sampled = self.sampler(logprobs) - mx.eval(sampled, logprobs) + + sampled, logprobs = _sample_first_token(req, last_logits) first_tokens.append(sampled.item()) all_logprobs.append(logprobs.squeeze(0)) @@ -1266,14 +1414,10 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: else: logits = self._run_vision_encoding(req, cache=request_cache) - # Extract last token logits and sample + # Extract last token logits last_logits = logits[:, -1, :] - logprobs = last_logits - mx.logsumexp( - last_logits, axis=-1, keepdims=True - ) - sampled = self.sampler(logprobs) - mx.eval(sampled, logprobs) + sampled, logprobs = _sample_first_token(req, last_logits) first_tokens.append(sampled.item()) all_logprobs.append(logprobs.squeeze(0)) @@ -1353,50 +1497,12 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # Create initial y (first generated tokens) y = mx.array(first_tokens) - # Build per-request logits processors (repetition_penalty, presence_penalty) - from mlx_lm.sample_utils import make_logits_processors, make_sampler - - batch_logits_processors = [] - has_any_lp = False - for req in requests: - need_rep = req.repetition_penalty and req.repetition_penalty != 1.0 - need_pres = req.presence_penalty and req.presence_penalty != 0.0 - if need_rep or need_pres: - lp_kwargs = {} - if need_rep: - lp_kwargs["repetition_penalty"] = req.repetition_penalty - if need_pres: - lp_kwargs["presence_penalty"] = req.presence_penalty - lp = make_logits_processors(**lp_kwargs) - batch_logits_processors.append(lp) - has_any_lp = True - logger.info( - f"[sampling] request={req.request_id[:12]} " - f"rep_penalty={req.repetition_penalty} " - f"pres_penalty={req.presence_penalty}" - ) - else: - batch_logits_processors.append(None) - - # Build per-request samplers for top_k/min_p - batch_samplers = [] - has_any_sampler = False - for req in requests: - if req.top_k != 0 or req.min_p != 0.0: - s = make_sampler( - temp=req.temperature, - top_p=req.top_p, - top_k=req.top_k, - min_p=req.min_p, - ) - batch_samplers.append(s) - has_any_sampler = True - logger.info( - f"[sampling] request={req.request_id[:12]} " - f"top_k={req.top_k} min_p={req.min_p}" - ) - else: - batch_samplers.append(None) + batch_logits_processors = [ + logits_processors_by_request.get(req.request_id) for req in requests + ] + has_any_lp = any(batch_logits_processors) + batch_samplers = [samplers_by_request.get(req.request_id) for req in requests] + has_any_sampler = any(batch_samplers) self._stats.prompt_time += time.perf_counter() - tic @@ -1455,10 +1561,15 @@ def _step( for e in range(logits.shape[0]): sample_logits = logits[e : e + 1] if logits_processors[e]: + # Build full context: output_tokens + current input token. + # ``output_tokens[e]`` lacks the current step's input token + # because it hasn't been appended yet; adding it here gives + # logits processors (e.g. JSON schema enforcer) accurate + # context about what has been generated so far. 
+ cur_tok = int(input_tokens[e, 0]) + full_context = output_tokens[e] + [cur_tok] for processor in logits_processors[e]: - sample_logits = processor( - mx.array(output_tokens[e]), sample_logits - ) + sample_logits = processor(mx.array(full_context), sample_logits) processed_logits.append(sample_logits) logits = mx.concatenate(processed_logits, axis=0) @@ -1584,11 +1695,13 @@ def _next(self) -> List[MLLMBatchResponse]: return error_responses y, logprobs = batch.y, batch.logprobs - output_tokens = ( - [req.output_tokens for req in batch.requests] - if batch.logits_processors - else None - ) + output_tokens = None + if batch.logits_processors: + y_list = y.tolist() + output_tokens = [ + list(req.output_tokens) + [token] + for req, token in zip(batch.requests, y_list) + ] batch.y, batch.logprobs = self._step( y[:, None], batch.cache, @@ -1626,6 +1739,26 @@ def _next(self) -> List[MLLMBatchResponse]: req.num_tokens = num_tok req.output_tokens.append(token) + if batch.logits_processors and _processors_can_retire( + batch.logits_processors[i] + ): + remaining_processors, retired_count = _drop_retired_processors( + batch.logits_processors[i] + ) + if retired_count > 0: + # Keep the per-request slot but replace an empty processor + # stack with None. The next `_mtp_step` uses any([None]) == + # False, so a fully retired request becomes MTP-eligible + # without changing batch alignment. + batch.logits_processors[i] = remaining_processors + logger.info( + "[MTP-MLLM] request=%s retired %d processor(s); " + "mtp_eligible_next_step=%s", + request_id[:12], + retired_count, + remaining_processors is None, + ) + finish_reason = None cache_fn = None @@ -1789,14 +1922,33 @@ def _mtp_step( ) -> Tuple[mx.array, List[mx.array]]: """Extended _step with MTP always-advance strategy.""" batch_size = input_tokens.shape[0] + active_requests = ( + list(batch_gen.active_batch.requests) + if batch_gen.active_batch is not None + else [] + ) + has_non_greedy_sampling = any( + getattr(req, "temperature", 0.0) not in (0, 0.0) + or getattr(req, "top_p", 1.0) < 1.0 + or getattr(req, "top_k", 0) != 0 + or getattr(req, "min_p", 0.0) != 0.0 + for req in active_requests + ) # Prefill guard: skip MTP for multi-token input or when no active batch # Also skip MTP when batch has multiple active requests (MTP overhead - # hurts aggregate throughput in concurrent scenarios) + # hurts aggregate throughput in concurrent scenarios). The current + # verifier is only correctness-safe for greedy decoding with no + # request-local logits processors. Accepted drafts are emitted directly + # from the greedy draft/argmax verify path; they do not pass through the + # request-local sampler. Non-greedy decoding needs a sampler-aware + # verifier before this guard can be safely relaxed. 
if ( input_tokens.shape[1] > 1 or batch_gen.active_batch is None or len(batch_gen.active_batch) > 1 + or has_non_greedy_sampling + or (logits_processors is not None and any(logits_processors)) ): _skip_state[0] = None return _orig_step( @@ -2084,8 +2236,14 @@ def _mtp_next() -> List[MLLMBatchResponse]: batch_gen._step = _mtp_step batch_gen._next = _mtp_next + if num_draft_tokens != 1: + logger.warning( + "[MTP-MLLM] num_draft_tokens=%d requested, but the current batched " + "MLLM MTP path drafts exactly one token per verify step", + num_draft_tokens, + ) total = _mtp_stats logger.info( f"[MTP-MLLM] installed with num_draft_tokens={num_draft_tokens}, " - f"always-advance verified mode" + f"effective_draft_tokens=1, always-advance verified mode" ) diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index bf6af19e1..0b41d57fe 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -19,7 +19,6 @@ """ import asyncio -import concurrent.futures import logging import time import uuid @@ -36,6 +35,7 @@ MLLMBatchRequest, MLLMBatchResponse, ) +from .mlx_streams import bind_generation_streams from .multimodal_processor import MultimodalProcessor from .request import RequestOutput, RequestStatus, SamplingParams @@ -346,7 +346,9 @@ def add_request( temperature: Sampling temperature top_p: Top-p sampling request_id: Optional custom request ID - **kwargs: Additional generation parameters + **kwargs: Additional generation parameters. ``logits_processors`` + — list of callables ``(tokens, logits) -> logits`` applied + during sampling (e.g. constrained JSON decoding). Returns: Request ID for tracking @@ -362,6 +364,7 @@ def add_request( min_p=kwargs.pop("min_p", 0.0), presence_penalty=kwargs.pop("presence_penalty", 0.0), repetition_penalty=kwargs.pop("repetition_penalty", 1.0), + logits_processors=kwargs.pop("logits_processors", None), ) request = MLLMRequest( @@ -509,6 +512,7 @@ def _schedule_waiting(self) -> List[MLLMRequest]: min_p=request.sampling_params.min_p, presence_penalty=request.sampling_params.presence_penalty, repetition_penalty=request.sampling_params.repetition_penalty, + logits_processors=request.sampling_params.logits_processors, ) batch_requests.append(batch_req) @@ -781,37 +785,91 @@ async def stop(self) -> None: async def _process_loop(self) -> None: """Main async processing loop. - Uses a thread pool executor for steps that involve prefill - (waiting requests or partial prefill in progress) so that the - event loop stays responsive for health checks and other HTTP - endpoints. Decode-only steps are fast (<3 ms) and run inline. + MLLM models are loaded on the server/event-loop thread, so their MLX + arrays and cache state must be consumed on that same thread. Unlike + the text-only EngineCore path, moving MLLM prefill to a worker crosses + MLX stream ownership and can fail with "no Stream in current thread". + + Text-only preprocessing (Jinja2 template rendering + tokenization) is + run BEFORE ``step()`` with ``await asyncio.sleep(0)`` yields between + each request. This prevents long preprocessing (10-30+ s for 40K+ + token conversations) from blocking health checks and new connections. 
""" - _executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mllm-step" - ) + streams_bound = False + + def _ensure_streams_bound() -> None: + nonlocal streams_bound + if not streams_bound: + bind_generation_streams() + streams_bound = True + loop = asyncio.get_running_loop() while self._running: try: + # --- Early preprocessing phase --- + # Run text-only preprocessing (Jinja2 template rendering + + # tokenization) in a thread-pool executor so the event loop + # stays responsive for health checks, new connections, and + # active streaming requests. Preprocessing is CPU-bound + # (no MLX GPU work) and HuggingFace tokenizers are + # thread-safe, so this is safe to offload. + bg = self.batch_generator + if bg is not None: + for req in list(getattr(bg, "unprocessed_requests", ())): + if req.input_ids is None and not req.images and not req.videos: + try: + tic = time.perf_counter() + await loop.run_in_executor( + None, bg._preprocess_request, req + ) + elapsed = time.perf_counter() - tic + if elapsed > 1.0: + n_tok = ( + req.input_ids.size + if req.input_ids is not None + else 0 + ) + logger.info( + f"Preprocessing {req.request_id[:12]}" + f": {n_tok} tokens in {elapsed:.2f}s" + ) + except Exception as e: + logger.error( + f"Early preprocessing failed for " + f"{req.request_id}: {e}" + ) + + # --- Step phase --- if self.has_requests(): - has_waiting = self.get_num_waiting() > 0 - has_partial = ( - self.batch_generator is not None - and getattr(self.batch_generator, "_partial", None) is not None - ) - needs_executor = has_waiting or has_partial - - if needs_executor: - await loop.run_in_executor(_executor, self.step) - else: - self.step() + _ensure_streams_bound() + tic = time.perf_counter() + self.step() + elapsed = time.perf_counter() - tic + if elapsed > 2.0: + logger.warning( + f"Slow MLLM step: {elapsed:.2f}s " + f"(waiting={len(self.waiting)}, " + f"running={len(self.running)})" + ) + # Yield multiple event-loop cycles so that pending + # HTTP health checks can complete. A single + # asyncio.sleep() gives only ONE _run_once() cycle, + # but an HTTP request needs ~3 cycles minimum: + # 1. accept TCP connection + # 2. read HTTP request / parse headers + # 3. run handler / write response + # Using repeated asyncio.sleep(0) gives many cycles + # with negligible wall-clock overhead (<1ms total). 
+ n_yields = 10 if elapsed > 1.0 else 5 + for _ in range(n_yields): await asyncio.sleep(0) else: # No work, wait a bit await asyncio.sleep(0.01) except asyncio.CancelledError: - break + raise except Exception as e: logger.error(f"Error in MLLM process loop: {e}", exc_info=True) await asyncio.sleep(0.1) @@ -880,10 +938,11 @@ async def stream_outputs( if output is None: finished_normally = True break - yield output if output.finished: finished_normally = True + yield output break + yield output finally: if not finished_normally: logger.info(f"Aborting orphaned MLLM request {request_id}") @@ -1064,6 +1123,23 @@ def get_stats(self) -> Dict[str, Any]: return stats + def clear_runtime_caches(self) -> Dict[str, bool]: + """Clear runtime caches without resetting scheduler/request state.""" + cleared = { + "vision_cache": False, + "prefix_cache": False, + } + if self.vision_cache: + self.vision_cache.clear() + cleared["vision_cache"] = True + if ( + self.batch_generator is not None + and self.batch_generator.prefix_cache is not None + ): + self.batch_generator.prefix_cache.clear() + cleared["prefix_cache"] = True + return cleared + def reset(self) -> None: """Reset the scheduler state.""" # Abort all requests diff --git a/vllm_mlx/mlx_streams.py b/vllm_mlx/mlx_streams.py new file mode 100644 index 000000000..16dc1130f --- /dev/null +++ b/vllm_mlx/mlx_streams.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Helpers for binding MLX generation streams to worker threads.""" + +import importlib +from collections.abc import Iterable + +import mlx.core as mx + + +def bind_generation_streams( + module_names: Iterable[str] = ("mlx_lm.generate", "mlx_vlm.generate"), +) -> object: + """Bind mlx-lm/mlx-vlm generation streams to the current thread. + + MLX streams are thread-local. If a model is loaded on one thread and + generation runs on another, module-level generation streams created during + import can point at a stream that does not exist in the worker thread. + """ + default_stream = mx.new_stream(mx.default_device()) + mx.set_default_stream(default_stream) + for module_name in module_names: + try: + module = importlib.import_module(module_name) + except ImportError: + continue + if hasattr(module, "generation_stream"): + module.generation_stream = default_stream + return default_stream diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 71f0af3ca..595f020aa 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -7,8 +7,12 @@ """ import logging +from collections.abc import Iterator from dataclasses import dataclass -from typing import Iterator +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + import mlx.core as mx logger = logging.getLogger(__name__) @@ -52,6 +56,7 @@ def __init__( tokenizer_name: str | None = None, trust_remote_code: bool = False, mtp: bool = False, + mtp_num_draft_tokens: int = 1, ): """ Initialize the MLX language model. 
@@ -61,11 +66,13 @@ def __init__( tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code mtp: Enable native MTP speculative decoding (model must have MTP head) + mtp_num_draft_tokens: Draft tokens per speculative MTP step """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code self._mtp = mtp + self._mtp_num_draft_tokens = mtp_num_draft_tokens self.model = None self.tokenizer = None @@ -98,11 +105,10 @@ def load(self) -> None: self._loaded = True logger.info(f"Model loaded successfully: {self.model_name}") - except ImportError: + except ImportError as err: raise ImportError( - "mlx-lm is required for LLM inference. " - "Install with: pip install mlx-lm" - ) + "mlx-lm is required for LLM inference. Install with: pip install mlx-lm" + ) from err except Exception as e: logger.error(f"Failed to load model: {e}") raise @@ -151,6 +157,7 @@ def generate( presence_penalty: float = 0.0, repetition_penalty: float = 1.0, stop: list[str] | None = None, + logits_processors: list | None = None, **kwargs, ) -> GenerationOutput: """ @@ -166,6 +173,9 @@ def generate( presence_penalty: Additive penalty for token presence repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences + logits_processors: Optional externally-supplied logits processors + (e.g. JSON schema constrained decoding). Merged with built-in + penalty processors. Returns: GenerationOutput with generated text and tokens @@ -177,9 +187,13 @@ def generate( # Create sampler and logits processors with full Unsloth params sampler = self._create_sampler(temperature, top_p, top_k, min_p) - logits_processors = self._create_logits_processors( + penalty_processors = self._create_logits_processors( presence_penalty, repetition_penalty ) + # Merge any externally-provided logits_processors with penalty processors + all_processors = penalty_processors or [] + if logits_processors: + all_processors = list(logits_processors) + all_processors # Generate text output_text = generate( @@ -188,7 +202,7 @@ def generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, - logits_processors=logits_processors, + logits_processors=all_processors if all_processors else None, verbose=False, ) @@ -206,7 +220,7 @@ def generate( def stream_generate( self, - prompt: str, + prompt: Union[str, "mx.array", list[int]], max_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, @@ -216,13 +230,14 @@ def stream_generate( repetition_penalty: float = 1.0, stop: list[str] | None = None, logits_processors: list | None = None, + prompt_cache=None, **kwargs, ) -> Iterator[StreamingOutput]: """ Stream text generation token by token. Args: - prompt: Input prompt text + prompt: Input prompt text, token array, or token id list max_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0 = greedy) top_p: Top-p (nucleus) sampling parameter @@ -231,6 +246,7 @@ def stream_generate( presence_penalty: Additive penalty for token presence repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences + prompt_cache: Pre-populated KV cache (e.g. 
from SpecPrefill) Yields: StreamingOutput for each generated token @@ -251,25 +267,32 @@ def stream_generate( all_processors = (logits_processors or []) + (penalty_processors or []) # Count prompt tokens once upfront - num_prompt_tokens = len(self.tokenizer.encode(prompt)) + if isinstance(prompt, str): + num_prompt_tokens = len(self.tokenizer.encode(prompt)) + else: + num_prompt_tokens = len(prompt) - token_count = 0 accumulated_text = "" mtp_kwargs = {} if self._mtp: mtp_kwargs["mtp"] = True - - for response in stream_generate( - self.model, - self.tokenizer, - prompt=prompt, - max_tokens=max_tokens, - sampler=sampler, - logits_processors=all_processors, - **mtp_kwargs, + mtp_kwargs["num_draft_tokens"] = self._mtp_num_draft_tokens + if prompt_cache is not None: + mtp_kwargs["prompt_cache"] = prompt_cache + + for token_count, response in enumerate( + stream_generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=all_processors, + **mtp_kwargs, + ), + start=1, ): - token_count += 1 # response.text is the new token text (not accumulated) new_text = response.text accumulated_text += new_text @@ -305,6 +328,7 @@ def chat( temperature: float = 0.7, top_p: float = 0.9, tools: list | None = None, + chat_template_kwargs: dict | None = None, **kwargs, ) -> GenerationOutput: """ @@ -335,6 +359,8 @@ def chat( # Add tools if provided and supported if tools: template_kwargs["tools"] = tools + if chat_template_kwargs: + template_kwargs.update(chat_template_kwargs) try: prompt = self.tokenizer.apply_chat_template( @@ -342,8 +368,10 @@ def chat( **template_kwargs, ) except TypeError: - # Tokenizer doesn't support tools parameter - del template_kwargs["tools"] + # Tokenizer doesn't support all requested template kwargs + template_kwargs.pop("tools", None) + for key in (chat_template_kwargs or {}).keys(): + template_kwargs.pop(key, None) prompt = self.tokenizer.apply_chat_template( messages, **template_kwargs, diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index ec761c3f2..dae22f99b 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -15,15 +15,17 @@ import atexit import base64 +import ipaddress import logging import math import os +import socket import tempfile import threading from collections.abc import Iterator from dataclasses import dataclass, field from pathlib import Path -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import numpy as np import requests @@ -115,6 +117,12 @@ class FileSizeExceededError(Exception): pass +class UnsafeRemoteURLError(ValueError): + """Raised when a remote media URL targets an unsafe destination.""" + + pass + + @dataclass class MultimodalInput: """Input for multimodal generation.""" @@ -182,6 +190,83 @@ def decode_base64_image( return base64.b64decode(base64_string) +def _validate_url_safety(url: str) -> None: + """Reject remote URLs that target local or private network resources.""" + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise UnsafeRemoteURLError( + f"Unsupported remote media URL scheme: {parsed.scheme or ''}" + ) + + hostname = parsed.hostname + if not hostname: + raise UnsafeRemoteURLError("Remote media URL must include a hostname") + + if hostname == "localhost" or hostname.endswith(".localhost"): + raise UnsafeRemoteURLError( + f"Remote media URL targets a blocked host: {hostname}" + ) + + try: + resolved_ips = [ipaddress.ip_address(hostname)] + except ValueError: + try: + addrinfo = socket.getaddrinfo( 
+ hostname, + parsed.port or (443 if parsed.scheme == "https" else 80), + type=socket.SOCK_STREAM, + ) + except socket.gaierror as exc: + raise UnsafeRemoteURLError( + f"Failed to resolve remote media host {hostname}: {exc}" + ) from exc + resolved_ips = [ipaddress.ip_address(info[4][0]) for info in addrinfo] + + blocked_ips = [str(ip) for ip in resolved_ips if not ip.is_global] + if blocked_ips: + raise UnsafeRemoteURLError( + f"Remote media URL resolves to blocked address(es): {', '.join(sorted(set(blocked_ips)))}" + ) + + +def _request_with_safe_redirects( + method: str, + url: str, + *, + timeout: int, + headers: dict[str, str], + stream: bool = False, + max_redirects: int = 5, +): + """Issue a requests call while validating every redirect target.""" + current_url = url + for _ in range(max_redirects + 1): + _validate_url_safety(current_url) + response = requests.request( + method, + current_url, + timeout=timeout, + headers=headers, + allow_redirects=False, + verify=True, + stream=stream, + ) + if not response.is_redirect and not response.is_permanent_redirect: + return response + + location = response.headers.get("location") + response.close() + if not location: + raise UnsafeRemoteURLError( + f"Remote media URL redirect missing Location header: {current_url}" + ) + current_url = urljoin(current_url, location) + + raise UnsafeRemoteURLError( + f"Remote media URL exceeded redirect limit ({max_redirects}): {url}" + ) + + def download_image(url: str, timeout: int = 30, max_size: int = MAX_IMAGE_SIZE) -> str: """ Download image from URL and return local path. @@ -203,8 +288,8 @@ def download_image(url: str, timeout: int = 30, max_size: int = MAX_IMAGE_SIZE) # First, make a HEAD request to check Content-Length try: - head_response = requests.head( - url, timeout=timeout, headers=headers, allow_redirects=True, verify=True + head_response = _request_with_safe_redirects( + "HEAD", url, timeout=timeout, headers=headers ) content_length = head_response.headers.get("content-length") if content_length and int(content_length) > max_size: @@ -216,8 +301,8 @@ def download_image(url: str, timeout: int = 30, max_size: int = MAX_IMAGE_SIZE) # HEAD request failed, proceed with GET and check during download pass - response = requests.get( - url, timeout=timeout, headers=headers, stream=True, verify=True + response = _request_with_safe_redirects( + "GET", url, timeout=timeout, headers=headers, stream=True ) response.raise_for_status() @@ -241,7 +326,7 @@ def download_image(url: str, timeout: int = 30, max_size: int = MAX_IMAGE_SIZE) ext = ".webp" else: # Try to get from URL - path = urlparse(url).path + path = urlparse(response.url).path ext = Path(path).suffix or ".jpg" # Save to temp file with size checking during download @@ -293,8 +378,8 @@ def download_video(url: str, timeout: int = 120, max_size: int = MAX_VIDEO_SIZE) # First, make a HEAD request to check Content-Length try: - head_response = requests.head( - url, timeout=timeout, headers=headers, allow_redirects=True, verify=True + head_response = _request_with_safe_redirects( + "HEAD", url, timeout=timeout, headers=headers ) content_length = head_response.headers.get("content-length") if content_length and int(content_length) > max_size: @@ -306,8 +391,8 @@ def download_video(url: str, timeout: int = 120, max_size: int = MAX_VIDEO_SIZE) # HEAD request failed, proceed with GET and check during download pass - response = requests.get( - url, timeout=timeout, headers=headers, stream=True, verify=True + response = _request_with_safe_redirects( + 
"GET", url, timeout=timeout, headers=headers, stream=True ) response.raise_for_status() @@ -333,7 +418,7 @@ def download_video(url: str, timeout: int = 120, max_size: int = MAX_VIDEO_SIZE) ext = ".mkv" else: # Try to get from URL - path = urlparse(url).path + path = urlparse(response.url).path ext = Path(path).suffix or ".mp4" # Save to temp file (stream for larger files) with size checking @@ -421,7 +506,6 @@ def process_video_input(video: str | dict) -> str: Process video input in various formats and return local path. Supports: - - Local file path - URL (http/https) - Base64 encoded string (data:video/mp4;base64,...) - OpenAI format dict: {"url": "..."} or {"url": "data:video/...;base64,..."} @@ -442,10 +526,6 @@ def process_video_input(video: str | dict) -> str: if not video: raise ValueError("Empty video input") - # Check if it's a local file - if Path(video).exists(): - return video - # Check if it's a URL if is_url(video): return download_video(video) @@ -454,7 +534,9 @@ def process_video_input(video: str | dict) -> str: if is_base64_video(video): return decode_base64_video(video) - raise ValueError(f"Cannot process video: {video[:50]}...") + raise ValueError( + "Unsupported video input. Only http(s) URLs and data:video base64 payloads are allowed." + ) # Cache for base64 images to avoid re-saving the same image @@ -505,7 +587,6 @@ def process_image_input(image: str | dict) -> str: Process image input in various formats and return local path. Supports: - - Local file path - URL (http/https) - Base64 encoded string - OpenAI format dict: {"url": "..."} or {"url": "data:image/...;base64,..."} @@ -528,11 +609,9 @@ def process_image_input(image: str | dict) -> str: if is_url(image): return download_image(image) - # Check if it's a local file (only for short strings that could be paths) - if len(image) < 4096 and Path(image).exists(): - return image - - raise ValueError(f"Cannot process image: {image[:50]}...") + raise ValueError( + "Unsupported image input. Only http(s) URLs and data:image base64 payloads are allowed." + ) def round_by_factor(x: int, factor: int) -> int: @@ -690,7 +769,7 @@ class MLXMultimodalLM: def __init__( self, model_name: str, - trust_remote_code: bool = True, + trust_remote_code: bool = False, enable_cache: bool = True, cache_size: int = 50, ): @@ -758,7 +837,7 @@ def get_tokenizer(self): return self.processor.tokenizer def _prepare_images(self, images: list) -> list[str]: - """Process image inputs and return local file paths.""" + """Process remote/base64 image inputs into local temp file paths.""" processed = [] for img in images: try: @@ -778,7 +857,6 @@ def _prepare_video( Process video input and extract frames. Supports: - - Local file paths - URLs (http/https) - will be downloaded - Base64 encoded videos (data:video/mp4;base64,...) - OpenAI format dicts: {"url": "..."} or {"video_url": {"url": "..."}} @@ -973,7 +1051,7 @@ def _translate_messages_for_native_video( ) -> list[dict]: """Translate OpenAI API format messages to process_vision_info format. - Converts video_url/video types and resolves URLs/base64 to local paths. + Converts video_url/video types and resolves remote/base64 inputs to local paths. Images are preserved as-is (process_vision_info handles them). 
""" translated = [] @@ -1074,8 +1152,8 @@ def generate( Args: prompt: Text prompt/question - images: List of image paths, URLs, or base64 strings - videos: List of video inputs (paths, URLs, base64, or OpenAI format dicts) + images: List of image URLs or base64 strings + videos: List of video inputs (URLs, base64, or OpenAI format dicts) audio: List of audio file paths max_tokens: Maximum tokens to generate temperature: Sampling temperature @@ -1319,9 +1397,10 @@ def chat( from mlx_vlm import generate from mlx_vlm.prompt_utils import get_chat_template - # Extract text and images from messages - # Build chat_messages for multi-turn support WITH proper image tokens per message + # Extract text, images and audio from messages + # Build chat_messages for multi-turn support WITH proper image/audio tokens per message all_image_urls = [] # Raw URLs/paths to process later + all_audio_urls = [] # Raw audio URLs/paths to process later chat_messages = [] # List of properly formatted messages for chat template logger.info(f"MLLM.chat() called with {len(messages)} messages") @@ -1405,17 +1484,27 @@ def chat( ) msg_image_count += 1 + elif item_type == "audio_url": + aud_url = item.get("audio_url", {}) + if isinstance(aud_url, str): + all_audio_urls.append(aud_url) + else: + all_audio_urls.append(aud_url.get("url", "")) + # Add video frame count to image count for this message msg_image_count += _msg_video_frame_counts.get(msg_idx, 0) + msg_audio_count = len(all_audio_urls) - # Build properly structured message for Qwen3-VL-MoE - # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]} - if msg_text or msg_image_count > 0: - if role == "user" and msg_image_count > 0: - # User message WITH images - build content array with image tokens FIRST + # Build properly structured message + # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "audio"}, ..., {"type": "text", "text": "..."}]} + if msg_text or msg_image_count > 0 or msg_audio_count > 0: + if role == "user" and (msg_image_count > 0 or msg_audio_count > 0): + # User message WITH images/audio - build content array with media tokens FIRST content_list = [] for _ in range(msg_image_count): content_list.append({"type": "image"}) + for _ in range(msg_audio_count): + content_list.append({"type": "audio"}) content_list.append( {"type": "text", "text": msg_text, "content": msg_text} ) @@ -1443,7 +1532,7 @@ def chat( # Apply chat template directly - messages are already properly structured logger.info( - f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images" + f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images, {len(all_audio_urls)} audios" ) for i, cm in enumerate(chat_messages): content_preview = str(cm.get("content", ""))[:80] @@ -1586,6 +1675,7 @@ def chat( self.processor, formatted_prompt, all_images if all_images else None, + audio=all_audio_urls if all_audio_urls else None, max_tokens=max_tokens, temp=temperature, verbose=False, @@ -1996,9 +2086,6 @@ def describe_video( Video description text Example: - # Local file - model.describe_video("video.mp4") - # URL model.describe_video("https://example.com/video.mp4") diff --git a/vllm_mlx/multimodal_processor.py b/vllm_mlx/multimodal_processor.py index a5c861216..aae33736e 100644 --- a/vllm_mlx/multimodal_processor.py +++ b/vllm_mlx/multimodal_processor.py @@ -108,8 +108,8 @@ def process( Args: prompt: Text prompt (already formatted with chat template) - images: List of image paths, URLs, 
or base64 strings - videos: List of video inputs + images: List of image URLs or base64 strings + videos: List of video URLs or base64 inputs video_fps: FPS for video frame extraction video_max_frames: Max frames per video add_special_tokens: Whether to add special tokens diff --git a/vllm_mlx/patches/qwen3_5_mllm.py b/vllm_mlx/patches/qwen3_5_mllm.py index c592928da..b8fd55014 100644 --- a/vllm_mlx/patches/qwen3_5_mllm.py +++ b/vllm_mlx/patches/qwen3_5_mllm.py @@ -89,6 +89,18 @@ def _patched_call( # but kv_seq_len must be int for mask[..., :kv_seq_len]. _offset = _cache_offset_to_int(cache) + # mlx-vlm caches position_ids on the language model object. + # If a subsequent request has a different prompt length, stale + # position_ids can be shorter than the current chunk and crash + # rotary application with a broadcast shape mismatch. + if position_ids is not None and position_ids.shape[-1] != L: + logger.debug( + "[Qwen3.5 patch] Recomputing stale position_ids: got %s, expected %s", + position_ids.shape[-1], + L, + ) + position_ids = None + if position_ids is None: kv_seq_len += _offset + 1 position_ids = mx.arange(_offset, _offset + L) @@ -115,6 +127,6 @@ def _patched_call( return self.o_proj(output * mx.sigmoid(gate)) Qwen3_5Attention.__call__ = _patched_call - Qwen3_5Attention._batch_patched = True + setattr(Qwen3_5Attention, "_batch_patched", True) logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support") return True diff --git a/vllm_mlx/patches/qwen3_5_mtp.py b/vllm_mlx/patches/qwen3_5_mtp.py index 3d5f3e632..758ddaa51 100644 --- a/vllm_mlx/patches/qwen3_5_mtp.py +++ b/vllm_mlx/patches/qwen3_5_mtp.py @@ -25,6 +25,36 @@ logger = logging.getLogger(__name__) +_QWEN_MTP_RMSNORM_WEIGHT_SUFFIXES = ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "q_norm.weight", + "k_norm.weight", + "pre_fc_norm_hidden.weight", + "pre_fc_norm_embedding.weight", + "norm.weight", +) + + +def _is_qwen_mtp_rmsnorm_weight(key: str, weight) -> bool: + """Return True for MTP RMSNorm weights that use Qwen's offset convention.""" + return weight.ndim == 1 and any( + key.endswith(suffix) for suffix in _QWEN_MTP_RMSNORM_WEIGHT_SUFFIXES + ) + + +def _apply_qwen_mtp_rmsnorm_offset_fixups(mtp_weights: dict) -> int: + """Apply Qwen raw-offset RMSNorm fixups without double-shifting MLX weights.""" + norm_fixup_count = 0 + for key, weight in list(mtp_weights.items()): + if not _is_qwen_mtp_rmsnorm_weight(key, weight): + continue + mean_val = weight.mean().item() + if mean_val < 0.5: + mtp_weights[key] = weight + 1.0 + norm_fixup_count += 1 + return norm_fixup_count + def _fixup_moe_mtp(mtp, inner_model, loaded_keys: set, mx) -> None: """Fix missing weights in MoE MTP module. 
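# --- Reviewer sketch (not part of the patch): the Qwen RMSNorm offset
# convention handled by _apply_qwen_mtp_rmsnorm_offset_fixups() above.
# Checkpoints store gamma as an offset (actual = 1 + stored), so raw tensors
# have a mean near 0 and already-converted tensors a mean near 1; the
# mean < 0.5 test tells them apart without double-shifting.
import mlx.core as mx

raw_offset = mx.zeros(8) + 0.02      # as stored in the HF checkpoint
already_fixed = raw_offset + 1.0     # as emitted by a conversion script

for w in (raw_offset, already_fixed):
    needs_fixup = w.mean().item() < 0.5
    gamma = w + 1.0 if needs_fixup else w
    print(f"needs_fixup={needs_fixup}, gamma mean={gamma.mean().item():.2f}")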
@@ -217,7 +247,7 @@ def __init__(self, args, n_layers): scales_key = key.replace(".weight", ".scales") biases_key = key.replace(".weight", ".biases") - if scales_key in raw_mtp and biases_key in raw_mtp: + if scales_key != key and scales_key in raw_mtp and biases_key in raw_mtp: # Quantized triplet → dequantize to BF16 dq = mx.dequantize( raw_mtp[key], @@ -234,6 +264,48 @@ def __init__(self, args, n_layers): processed.add(key) del raw_mtp + # --- Convert fused expert format to split format --- + # Qwen3.6 MTP uses fused expert keys: + # layers.X.mlp.experts.gate_up_proj [n_experts, 2*intermediate, hidden] + # layers.X.mlp.experts.down_proj [n_experts, hidden, intermediate] + # but mlx_lm's DecoderLayer expects split switch_mlp keys: + # layers.X.mlp.switch_mlp.gate_proj.weight [n_experts, intermediate, hidden] + # layers.X.mlp.switch_mlp.up_proj.weight [n_experts, intermediate, hidden] + # layers.X.mlp.switch_mlp.down_proj.weight [n_experts, hidden, intermediate] + for key in list(mtp_weights.keys()): + if ".mlp.experts.gate_up_proj" in key: + prefix = key.replace(".mlp.experts.gate_up_proj", "") + w = mtp_weights.pop(key) + intermediate = w.shape[1] // 2 + gate_key = f"{prefix}.mlp.switch_mlp.gate_proj.weight" + up_key = f"{prefix}.mlp.switch_mlp.up_proj.weight" + mtp_weights[gate_key] = w[:, :intermediate, :] + mtp_weights[up_key] = w[:, intermediate:, :] + logger.info( + "[MTP inject] Split fused experts.gate_up_proj -> " + "switch_mlp.{gate_proj,up_proj}" + ) + elif ".mlp.experts.down_proj" in key: + prefix = key.replace(".mlp.experts.down_proj", "") + w = mtp_weights.pop(key) + down_key = f"{prefix}.mlp.switch_mlp.down_proj.weight" + mtp_weights[down_key] = w + logger.info( + "[MTP inject] Renamed experts.down_proj -> switch_mlp.down_proj" + ) + + # --- Fixup RMSNorm weights: HuggingFace offset convention --- + # Qwen3.5/3.6 models store RMSNorm weights as offsets (actual = 1 + stored). + # The main model's sanitize() handles this, but MTP weights bypass sanitize. + # Detect raw-offset weights (mean < 0.5) and apply +1.0; skip if already + # in actual-gamma space (as produced by add_mtp_weights_qwen35.py). + norm_fixup_count = _apply_qwen_mtp_rmsnorm_offset_fixups(mtp_weights) + if norm_fixup_count > 0: + logger.info( + f"[MTP inject] Applied +1.0 RMSNorm offset to {norm_fixup_count} " + f"norm weights (raw HF offset detected)" + ) + mtp.load_weights(list(mtp_weights.items()), strict=False) mx.eval(mtp.parameters()) diff --git a/vllm_mlx/prompt_warmup.py b/vllm_mlx/prompt_warmup.py new file mode 100644 index 000000000..7511e4878 --- /dev/null +++ b/vllm_mlx/prompt_warmup.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Prompt warm-up for vllm-mlx. + +At server startup, pre-populates the prefix cache by running one short +generation per warm-up prompt. The first user request that shares a prefix +with a warmed prompt sees cache-hit TTFT instead of cold prefill latency. + +File format (JSON): + [ + [{"role": "system", "content": "You are ..."}], + [{"role": "system", "content": "..."}, {"role": "user", "content": "hi"}] + ] + +Each entry is a list of chat messages — same shape as a ``/v1/chat/completions`` +``messages`` field. The warmer runs a ``max_tokens=1`` chat completion for each, +which flows through the exact same path as a real request and writes the KV +state to the prefix cache. + +Paths resolve from the current working directory. A single-message system +prompt is sufficient if that is the shared prefix. 
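# --- Reviewer sketch (not part of the patch): a minimal warm-prompts file in
# the shape described above — one entry per persona, each a chat `messages`
# list — plus the loader round-trip.
import json

from vllm_mlx.prompt_warmup import load_warmup_file

warm_prompts = [
    [{"role": "system", "content": "You are a terse coding assistant."}],
    [
        {"role": "system", "content": "You are a friendly support agent."},
        {"role": "user", "content": "hi"},
    ],
]
with open("warm_prompts.json", "w") as f:
    json.dump(warm_prompts, f, indent=2)

assert len(load_warmup_file("warm_prompts.json")) == 2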
+ +Sizing note: prompts are warmed concurrently via ``asyncio.gather``, so N +entries fire N concurrent prefills at startup. Each prefill allocates KV +cache for its prompt length. For typical agent deployments 1–3 entries +(one per active persona) cover the hot paths; a very large warm-prompts +file on a memory-tight model can exhaust headroom at boot. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def load_warmup_file(path: str) -> list[list[dict[str, Any]]]: + """Load and validate a warm-up prompts JSON file. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the file shape is invalid. + """ + p = Path(path).expanduser() + if not p.exists(): + raise FileNotFoundError(f"Warm-up prompts file not found: {p}") + + data = json.loads(p.read_text()) + if not isinstance(data, list): + raise ValueError( + f"Warm-up file must contain a top-level JSON list, got {type(data).__name__}" + ) + + if not data: + raise ValueError(f"Warm-up file is empty: {p}") + + for i, entry in enumerate(data): + if not isinstance(entry, list) or not entry: + raise ValueError( + f"Warm-up entry {i}: expected non-empty list of message dicts" + ) + for j, msg in enumerate(entry): + if not isinstance(msg, dict): + raise ValueError( + f"Warm-up entry {i} message {j}: expected dict, got {type(msg).__name__}" + ) + if "role" not in msg or "content" not in msg: + raise ValueError( + f"Warm-up entry {i} message {j}: missing 'role' or 'content'" + ) + + return data + + +def _ensure_user_terminator(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Ensure the message list ends with a user message. + + Some chat templates (Qwen3.6, DeepSeek-VL, a handful of others) require + at least one user message or raise ``TemplateError: No user query found``. + We prefer to cache just the system prefix, but when the template won't + render without a user, append a minimal placeholder. The common prefix + up to the start of user content still matches real requests, so the + system tokens still get cached. + """ + if messages and messages[-1].get("role") in ("user", "assistant"): + return messages + return [*messages, {"role": "user", "content": " "}] + + +def _build_strict_prefix_string( + tokenizer: Any, messages: list[dict[str, Any]], enable_thinking: bool = True +) -> str | None: + """Build a STRING prefix that is a prefix of any real request's rendered + chat template for the same system and empty chat history. + + Strategy: render the chat template twice with two DIFFERENT user contents + and ``tokenize=False`` (matching what the server does). Truncate the + first output at the position where the two strings diverge — that's + where user content gets inserted. + + We return a STRING (not tokens) because the engine's request path also + applies the template with ``tokenize=False`` and then lets the tokenizer + encode the result. Going through the same pipeline guarantees the warm + entry's tokens are a strict prefix of a real request's tokens. + + This enables warm-prompts on hybrid SSM+attention models where LCP + matching is disabled (SSM state can't be trimmed) — they rely purely + on strict PREFIX match. + + Returns None if rendering fails or the two probes don't diverge past a + reasonable prefix length (unusual template). 
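# --- Reviewer sketch (not part of the patch): the probe/divergence idea in
# plain code. The rendered strings below are made-up stand-ins for a chat
# template; the point is only that the first differing character marks where
# user content begins, and everything before it is a safe shared prefix.
a = "<|system|>You are helpful.<|user|>Alpha<|assistant|>"
b = "<|system|>You are helpful.<|user|>Bravo<|assistant|>"

boundary = next(
    (i for i, (ca, cb) in enumerate(zip(a, b)) if ca != cb),
    min(len(a), len(b)),
)
print(a[:boundary])  # "<|system|>You are helpful.<|user|>"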
+ """ + apply = getattr(tokenizer, "apply_chat_template", None) + if apply is None: + return None + + def _with_user(user_content: str) -> list[dict[str, Any]]: + msgs = [dict(m) for m in messages] + if msgs and msgs[-1].get("role") == "user": + msgs[-1] = {**msgs[-1], "content": user_content} + else: + msgs = [*msgs, {"role": "user", "content": user_content}] + return msgs + + kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + + # Multi-character probes that differ at position 0. Multi-char is safer + # than "A"/"B" against hypothetical templates that treat single-character + # content specially (e.g. stripping or wrapping). Differ-at-position-0 is + # required: probes that share a prefix (e.g. "__PROBE_A__"/"__PROBE_B__") + # would make the divergence loop overshoot into the shared-prefix region + # and cache bytes that are not in real requests. + probe_a, probe_b = "Alpha", "Bravo" + try: + a = apply(_with_user(probe_a), **kwargs) + b = apply(_with_user(probe_b), **kwargs) + except Exception: + # Template may reject enable_thinking on non-Qwen models — retry without + try: + kwargs.pop("enable_thinking", None) + a = apply(_with_user(probe_a), **kwargs) + b = apply(_with_user(probe_b), **kwargs) + except Exception: + return None + + if not isinstance(a, str) or not isinstance(b, str): + return None + + # Character-level divergence — the boundary is where user content starts. + # Track whether we actually diverged; identical probes mean the template + # ignored user content (unusual — bail out rather than cache the full + # rendering which would include whatever template fallback was used). + boundary = 0 + diverged = False + for i in range(min(len(a), len(b))): + if a[i] != b[i]: + diverged = True + break + boundary = i + 1 + + if not diverged: + return None + + # Require a reasonable prefix. Too-short means the template is unusual. + if boundary < 16: + return None + + return a[:boundary] + + +async def warm_prefix_cache( + engine: Any, + prompts: list[list[dict[str, Any]]], + *, + max_tokens: int = 1, +) -> dict[str, Any]: + """Run each prompt through the engine to populate the prefix cache. + + Prefers the strict-prefix path when the engine exposes a tokenizer: + manually tokenize with ``add_generation_prompt=False`` and feed the + raw token IDs to the engine's ``stream_generate`` (which accepts + ``prompt: str | list[int]``). Real requests — which always use + ``add_generation_prompt=True`` — will then find the warm entry as an + exact strict prefix, independent of the engine's LCP matcher. + + This is the difference between warm-prompts helping dense models only + and helping hybrid SSM+attention models too. + + Falls back to ``engine.stream_chat`` with a placeholder user message + appended if no tokenizer is exposed — strict-prefix match won't apply + there, so the feature is effectively LCP-only for that engine. + + Runs all prompts concurrently (``asyncio.gather``). + + Args: + engine: The vllm-mlx engine (exposes ``stream_chat`` and optionally + ``tokenizer`` + ``stream_generate``). + prompts: List of message arrays. + max_tokens: Tokens to generate per warm-up. 1 is enough. + + Returns: + Dict with ``count``, ``skipped``, ``elapsed_ms``, + ``total_prompt_tokens``, and ``mode`` (``"strict-prefix"`` or + ``"chat-fallback"``) describing which path was used. 
+ """ + tokenizer = getattr(engine, "tokenizer", None) + use_strict_prefix = tokenizer is not None and hasattr(engine, "stream_generate") + + async def _one_strict( + idx: int, messages: list[dict[str, Any]] + ) -> tuple[int, int, str | None]: + prefix_str = _build_strict_prefix_string(tokenizer, messages) + if prefix_str is None: + return await _one_chat(idx, messages) + try: + async for output in engine.stream_generate( + prompt=prefix_str, + max_tokens=max_tokens, + temperature=0.0, + ): + if output.finished: + return 1, int(output.prompt_tokens or 0), None + return 0, 0, "no finished output" + except Exception as e: + err = f"{type(e).__name__}: {str(e)[:120]}" + logger.warning("[warmup] prompt %d (strict) failed: %s", idx, err) + return await _one_chat(idx, messages) + + async def _one_chat( + idx: int, messages: list[dict[str, Any]] + ) -> tuple[int, int, str | None]: + patched = _ensure_user_terminator(messages) + try: + async for output in engine.stream_chat( + messages=patched, + max_tokens=max_tokens, + temperature=0.0, + ): + if output.finished: + return 1, int(output.prompt_tokens or 0), None + return 0, 0, "no finished output" + except Exception as e: + err = f"{type(e).__name__}: {str(e)[:120]}" + logger.warning("[warmup] prompt %d (chat) failed: %s", idx, err) + return 0, 0, err + + runner = _one_strict if use_strict_prefix else _one_chat + mode = "strict-prefix" if use_strict_prefix else "chat-fallback" + + t0 = time.perf_counter() + results = await asyncio.gather( + *(runner(i, msgs) for i, msgs in enumerate(prompts)), + return_exceptions=False, + ) + elapsed_ms = (time.perf_counter() - t0) * 1000 + + completed = sum(r[0] for r in results) + total_prompt_tokens = sum(r[1] for r in results) + skipped = sum(1 for r in results if r[2] is not None) + + return { + "count": completed, + "skipped": skipped, + "elapsed_ms": elapsed_ms, + "total_prompt_tokens": total_prompt_tokens, + "mode": mode, + } diff --git a/vllm_mlx/reasoning/base.py b/vllm_mlx/reasoning/base.py index aaefef9c1..3e91f52ef 100644 --- a/vllm_mlx/reasoning/base.py +++ b/vllm_mlx/reasoning/base.py @@ -108,3 +108,19 @@ def reset_state(self): # noqa: B027 This is intentionally a default no-op implementation. """ pass + + def finalize_stream(self) -> DeltaMessage | None: # noqa: B027 + """ + Finalize streaming state at end of stream. + + Called after the last delta is processed but before the stream + closes. Parsers that buffer partial markers internally should + flush any remaining text here. + + Default implementation is a no-op (returns None). + + Returns: + DeltaMessage with any pending reasoning/content to emit, + or None if nothing to flush. 
+ """ + return None diff --git a/vllm_mlx/reasoning/deepseek_r1_parser.py b/vllm_mlx/reasoning/deepseek_r1_parser.py index b633781de..d29664163 100644 --- a/vllm_mlx/reasoning/deepseek_r1_parser.py +++ b/vllm_mlx/reasoning/deepseek_r1_parser.py @@ -53,10 +53,7 @@ def extract_reasoning( """ # If we have end token but no start token, treat beginning as reasoning if self.end_token in model_output and self.start_token not in model_output: - reasoning, _, content = model_output.partition(self.end_token) - reasoning = reasoning.strip() or None - content = content.strip() or None - return reasoning, content + return self._extract_complete_reasoning(model_output) # If neither token, return as pure content if self.end_token not in model_output and self.start_token not in model_output: diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py index 8b6dd8149..cc3ea3071 100644 --- a/vllm_mlx/reasoning/gemma4_parser.py +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -20,6 +20,14 @@ to transition from thinking to response mode. This parser handles both. When thinking is disabled or not triggered, output contains no tags. + +Degenerate cycling: + On long prompts with tools, Gemma 4 may oscillate between thought and + response channels many times, producing garbage reasoning before finally + emitting valid content/tool_calls. The parser handles this by splitting + at the LAST so all cycles go into reasoning_content and only + the final response goes into content. Channel tokens are stripped from + both sides. """ from .base import DeltaMessage @@ -28,6 +36,11 @@ # Channel names that follow <|channel> — stripped from output _THOUGHT_PREFIX = "thought" _RESPONSE_MARKER = "<|channel>response" +# Full "open-thinking" marker. Kept as a buffering target so that partial +# prefixes arriving mid-stream (e.g. "<|channel>th", "<|channel>thoug") are +# held back until the label completes, instead of leaking 'th', 'tho', etc. +# into the reasoning output. +_THOUGHT_MARKER = "<|channel>thought" def _strip_channel_name(text: str, prefix: str) -> str: @@ -37,6 +50,38 @@ def _strip_channel_name(text: str, prefix: str) -> str: return text.lstrip("\n") +def _strip_channel_tokens(text: str) -> str: + """Remove all channel special tokens and bare channel names from text. + + Handles degenerate model output with multiple thought/response cycles + by stripping all protocol tokens, leaving only the actual text content. + """ + # Remove special tokens + text = text.replace("", "") + text = text.replace("<|channel>", "") + # Remove channel names on standalone lines + lines = text.split("\n") + cleaned = [] + for line in lines: + s = line.strip() + if s in ("thought", "response"): + continue + cleaned.append(line) + text = "\n".join(cleaned) + # Strip leading channel name + text = text.strip() + for name in ("thought", "response"): + if text.startswith(name + "\n"): + text = text[len(name) + 1 :] + break + if text.startswith(name) and ( + len(text) == len(name) or not text[len(name)].isalpha() + ): + text = text[len(name) :] + break + return text.strip() + + class Gemma4ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for Gemma 4 models. @@ -52,6 +97,18 @@ class Gemma4ReasoningParser(BaseThinkingReasoningParser): Output: reasoning="Let me think...", content="The answer is 42." When no tags are present, the entire output is treated as content. 
+ + Degenerate cycling (long prompts + tools): + Uses rpartition to split at the LAST , so all intermediate + thought/response cycles go into reasoning and only the final response + goes into content. + + Streaming buffering: + Partial markers at a delta boundary (e.g. "<|channel>" without a + following "response" yet) are buffered internally so they don't + leak into reasoning/content. The buffer is either consumed when + the marker completes in a later delta, or flushed as reasoning + via finalize_stream() when the stream ends. """ @property @@ -62,6 +119,63 @@ def start_token(self) -> str: def end_token(self) -> str: return "" + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + # Trailing text withheld because it could complete into a marker + # (e.g. "<|channel>" that might become "<|channel>response" next delta). + self._pending: str = "" + # Tracks whether we have emitted the first content delta past the + # <|channel>response transition — used to strip the leading newline. + self._content_seen: bool = False + + def reset_state(self): + super().reset_state() + self._pending = "" + self._content_seen = False + + def _trailing_partial_marker_len(self, text: str) -> int: + """ + Return length of trailing substring of `text` that is a proper prefix + of any transition marker (, <|channel>response, <|channel>). + + Only counts PROPER prefixes — if the marker is already complete in + `text`, no buffering is needed. Returns 0 if no partial match. + + We must never buffer legitimate content. For <|channel>, only buffer + when it appears AT THE END and is not followed by more text (i.e., + `response` or `thought` hasn't arrived yet). + """ + markers = (_RESPONSE_MARKER, _THOUGHT_MARKER, self.end_token, self.start_token) + max_len = 0 + for marker in markers: + # Scan from longest proper prefix downwards + for i in range(min(len(marker) - 1, len(text)), 0, -1): + if text.endswith(marker[:i]): + # Proper prefix match; ensure we're not inside a completed + # marker (which would already be handled by other logic). + if not text.endswith(marker): + if i > max_len: + max_len = i + break + return max_len + + def finalize_stream(self) -> DeltaMessage | None: + """ + Flush any buffered partial marker at the end of stream. + + If the stream ends while we have a partial marker buffered (e.g. + model emitted "<|channel>" as its last token and got truncated by + max_tokens), emit it as reasoning so the client doesn't lose the + text. Content phase flushes as content. + """ + if not self._pending: + return None + pending = self._pending + self._pending = "" + if self._phase == "content": + return DeltaMessage(content=pending) + return DeltaMessage(reasoning=pending) + def extract_reasoning( self, model_output: str, @@ -69,38 +183,44 @@ def extract_reasoning( """ Extract reasoning from complete output. - Handles both and <|channel>response as transition markers. - Strips channel names ("thought", "response") from output. + Uses rpartition (LAST ) to handle degenerate cycling: + all intermediate thought/response cycles go into reasoning, + only the final response goes into content. Channel tokens + are stripped from both sides. 
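# --- Reviewer sketch (not part of the patch): why rpartition. With the end
# marker replaced by a stand-in string, splitting at the LAST occurrence keeps
# every intermediate cycle on the reasoning side; the earlier marker left
# inside `reasoning` is what _strip_channel_tokens() then removes.
END = "<END>"  # stand-in for the model's real end-of-thought token
out = "first thoughts" + END + "draft answer, more thoughts" + END + "final answer"

reasoning, _, content = out.rpartition(END)
print(content)    # "final answer"
print(reasoning)  # everything before the last marker, across both cycles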
""" text = model_output - # Try standard format first: <|channel>thought...response + # Standard format: <|channel>thought...content + # Use rpartition on LAST to handle multiple cycles if self.start_token in text and self.end_token in text: _, _, after_start = text.partition(self.start_token) - reasoning, _, content = after_start.partition(self.end_token) - reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) - content = content.strip() + reasoning, _, content = after_start.rpartition(self.end_token) + reasoning = _strip_channel_tokens(reasoning) + content = _strip_channel_tokens(content) return reasoning or None, content or None - # Try alternative format: <|channel>thought...<|channel>response... + # Alternative format: <|channel>thought...<|channel>response... + # Use rfind for the LAST <|channel>response marker if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text: _, _, after_start = text.partition(self.start_token) - reasoning, _, content = after_start.partition(_RESPONSE_MARKER) - reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) - content = content.lstrip("\n").strip() + last_resp = after_start.rfind(_RESPONSE_MARKER) + reasoning = after_start[:last_resp] + content = after_start[last_resp + len(_RESPONSE_MARKER) :] + reasoning = _strip_channel_tokens(reasoning) + content = _strip_channel_tokens(content) return reasoning or None, content or None # Only closing tag (think injected in prompt) if self.end_token in text: - reasoning, _, content = text.partition(self.end_token) - reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) - content = content.strip() + reasoning, _, content = text.rpartition(self.end_token) + reasoning = _strip_channel_tokens(reasoning) + content = _strip_channel_tokens(content) return reasoning or None, content or None # Only start tag (incomplete reasoning, no end yet) if self.start_token in text: _, _, reasoning = text.partition(self.start_token) - reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + reasoning = _strip_channel_tokens(reasoning) return reasoning or None, None # No tags at all — pure content @@ -119,52 +239,142 @@ def extract_reasoning_streaming( - No tags: treat as content (Gemma 4 doesn't inject tags in prompt) - <|channel>thought: enter reasoning mode, strip channel name - or <|channel>response: transition to content mode + - Re-entry into thought from content (degenerate cycling): back to reasoning + + Partial markers at delta boundaries are buffered internally to + prevent leaking them as reasoning/content. + """ + # Buffer trailing partial marker so it doesn't leak into output. + # Process only the "safe" portion (without the partial trailing bytes). + trailing = self._trailing_partial_marker_len(current_text) + safe_current = current_text[:-trailing] if trailing else current_text + prev_trailing = self._trailing_partial_marker_len(previous_text) + safe_previous = ( + previous_text[:-prev_trailing] if prev_trailing else previous_text + ) + + # Update buffered pending text for external flush (finalize_stream). + self._pending = current_text[len(safe_current) :] + + # If no new safe text this delta, suppress emission — everything + # new is buffered as a potential marker prefix. 
+ if len(safe_current) <= len(safe_previous): + return None + + safe_delta = safe_current[len(safe_previous) :] + + return self._extract_from_safe_text(safe_previous, safe_current, safe_delta) + + @staticmethod + def _strip_channel_tokens_from_delta( + msg: DeltaMessage | None, + ) -> DeltaMessage | None: + """Strip channel special tokens from content and reasoning in a delta.""" + if msg is None: + return None + c = msg.content + r = msg.reasoning + if c is not None: + c = c.replace("", "").replace("<|channel>", "") + if r is not None: + r = r.replace("", "").replace("<|channel>", "") + if not c and not r: + return None + if c == msg.content and r == msg.reasoning: + return msg + return DeltaMessage(reasoning=r or None, content=c or None) + + def _extract_from_safe_text( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """Parse safe (non-buffered) text. + + Uses count-based detection for channel tokens so that multiple + thought/response cycles (degenerate model behaviour) are handled + correctly — each NEW <|channel> re-enters reasoning, each NEW + transitions to content. """ # No channel tokens at all — plain content if self.start_token not in current_text and self.end_token not in current_text: return DeltaMessage(content=delta_text) - # Check for alternative transition: <|channel>response - if _RESPONSE_MARKER in current_text: - if _RESPONSE_MARKER not in previous_text: - # Transition happening in this delta - # Find what (if any) content comes after the marker - marker_pos = current_text.find(_RESPONSE_MARKER) - after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] - after_marker = after_marker.lstrip("\n") - if after_marker: - return DeltaMessage(content=after_marker) - return None # Suppress the marker itself - else: - # Already past transition — pure content - # But we need to only emit the NEW text (delta) - return DeltaMessage(content=delta_text) - - # Delegate to base class for standard <|channel>/ handling - result = super().extract_reasoning_streaming( - previous_text, current_text, delta_text - ) + # ── Alternative transition: <|channel>response ── + # Check BEFORE generic <|channel> count — <|channel>response contains + # <|channel> but is a content transition, not a re-entry to reasoning. 
+ if _RESPONSE_MARKER in current_text and _RESPONSE_MARKER not in previous_text: + self._phase = "content" + self._content_seen = False + marker_pos = current_text.find(_RESPONSE_MARKER) + after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] + after_marker = after_marker.lstrip("\n") + if after_marker: + self._content_seen = True + return self._strip_channel_tokens_from_delta( + DeltaMessage(content=after_marker) + ) + return None + + cur_starts = current_text.count(self.start_token) + prev_starts = previous_text.count(self.start_token) + cur_ends = current_text.count(self.end_token) + prev_ends = previous_text.count(self.end_token) + + # ── NEW <|channel> (enter / re-enter reasoning) ── + if cur_starts > prev_starts: + if self._phase != "thinking": + self._phase = "thinking" + self._content_seen = False + return None # suppress marker + channel name + + # ── NEW (transition to content) ── + if cur_ends > prev_ends: + self._phase = "content" + self._content_seen = False + # Text after the last in this delta is content + last_end = delta_text.rfind(self.end_token) + if last_end >= 0: + after = delta_text[last_end + len(self.end_token) :] + after = _strip_channel_name(after.lstrip("\n"), _THOUGHT_PREFIX) + after = _strip_channel_name(after, "response") + if after: + self._content_seen = True + return DeltaMessage(content=after) + return None + + # ── Content phase ── + if self._phase == "content": + if not self._content_seen: + stripped = delta_text.lstrip("\n") + stripped = _strip_channel_name(stripped, _THOUGHT_PREFIX) + stripped = _strip_channel_name(stripped, "response") + self._content_seen = bool(stripped) + if not stripped: + return None + return self._strip_channel_tokens_from_delta( + DeltaMessage(content=stripped) + ) + return self._strip_channel_tokens_from_delta( + DeltaMessage(content=delta_text) + ) - # Strip "thought" channel name from initial reasoning - if result is not None and result.reasoning is not None: - r = result.reasoning - # First reasoning delta after <|channel> will be "thought" or "thought\n" - if self.start_token in current_text: - # Check if this is the very first reasoning content - after_channel = current_text.split(self.start_token, 1)[1] - if after_channel.startswith(_THOUGHT_PREFIX): - # Remove "thought" prefix from the accumulated reasoning so far - clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n") - # Compute what portion of clean text is in this delta + # ── Thinking phase: emit as reasoning ── + if self._phase == "thinking": + # Strip "thought" channel name from initial reasoning delta + if cur_starts > 0: + after_ch = current_text.split(self.start_token, 1)[1] + if after_ch.startswith(_THOUGHT_PREFIX): + clean = after_ch[len(_THOUGHT_PREFIX) :].lstrip("\n") prev_after = "" if self.start_token in previous_text: prev_after = previous_text.split(self.start_token, 1)[1] if prev_after.startswith(_THOUGHT_PREFIX): prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n") - # The new reasoning text is clean minus what was already emitted - new_reasoning = clean[len(prev_after) :] - if new_reasoning: - return DeltaMessage(reasoning=new_reasoning) - return None # Suppress channel name token + r = clean[len(prev_after) :] + return DeltaMessage(reasoning=r) if r else None + return DeltaMessage(reasoning=delta_text) if delta_text else None - return result + # ── pre_think: first delta has no markers yet — emit as reasoning ── + return DeltaMessage(reasoning=delta_text) if delta_text else None diff --git 
a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py index 8596dafbf..4c7f9719a 100644 --- a/vllm_mlx/reasoning/think_parser.py +++ b/vllm_mlx/reasoning/think_parser.py @@ -55,10 +55,14 @@ def __init__(self, tokenizer=None): super().__init__(tokenizer) # Streaming state — reset per request via reset_state() self._phase: str = "pre_think" # "pre_think" | "thinking" | "content" + self._content_started = False + self._content_buffer = "" def reset_state(self): """Reset state machine for a new streaming request.""" self._phase = "pre_think" + self._content_started = False + self._content_buffer = "" def extract_reasoning( self, @@ -80,21 +84,12 @@ def extract_reasoning( """ text = model_output - # Case 1: Both tags present (normal case) - if self.start_token in text and self.end_token in text: - _, _, after_start = text.partition(self.start_token) - reasoning, _, content = after_start.partition(self.end_token) - # Strip duplicate end tokens (some models generate ) - while content.lstrip().startswith(self.end_token): - content = content.lstrip()[len(self.end_token) :] - return reasoning.strip() or None, content.strip() or None - - # Case 2: Only closing tag (think was injected in prompt) + # Cases 1 and 2: consume one or more leading reasoning spans. Some + # thinking models emit an extra empty ```` block after + # the forced transition; that block still belongs to reasoning, not + # final content. if self.end_token in text: - reasoning, _, content = text.partition(self.end_token) - while content.lstrip().startswith(self.end_token): - content = content.lstrip()[len(self.end_token) :] - return reasoning.strip() or None, content.strip() or None + return self._extract_complete_reasoning(text) # Case 3: Only start tag (incomplete reasoning, no end yet) if self.start_token in text: @@ -153,12 +148,7 @@ def extract_reasoning_streaming( eidx = after.find(end_tok) reasoning = after[:eidx] content = after[eidx + len(end_tok) :] - if not reasoning and not content: - return None - return DeltaMessage( - reasoning=reasoning or None, - content=content or None, - ) + return self._transition_to_content(reasoning, content) return DeltaMessage(reasoning=after) if after else None # Implicit mode: completed without an explicit . @@ -171,12 +161,7 @@ def extract_reasoning_streaming( else: reasoning = None content = delta_text - if not reasoning and not content: - return None - return DeltaMessage( - reasoning=reasoning or None, - content=content or None, - ) + return self._transition_to_content(reasoning, content) # No tags — default to reasoning (implicit mode assumption). # If the model doesn't use thinking at all, the server's @@ -196,14 +181,109 @@ def extract_reasoning_streaming( else: reasoning = delta_text content = None - if not reasoning and not content: - return None - return DeltaMessage( - reasoning=reasoning or None, - content=content or None, - ) + return self._transition_to_content(reasoning, content) return DeltaMessage(reasoning=delta_text) # ── Phase: content ──────────────────────────────────────── # Past the reasoning block — everything is content. 
- return DeltaMessage(content=delta_text) + return self._content_delta(delta_text) + + def _extract_complete_reasoning(self, text: str) -> tuple[str | None, str | None]: + """Split complete output into leading reasoning spans and final content.""" + reasoning_parts: list[str] = [] + remainder = text + + while remainder: + stripped = remainder.lstrip() + + if stripped.startswith(self.start_token): + after_start = stripped[len(self.start_token) :] + reasoning, found, after_end = after_start.partition(self.end_token) + if not found: + reasoning_parts.append(reasoning) + remainder = "" + break + if reasoning.strip(): + reasoning_parts.append(reasoning.strip()) + remainder = after_end + continue + + start_idx = stripped.find(self.start_token) + end_idx = stripped.find(self.end_token) + if end_idx != -1 and (start_idx == -1 or end_idx < start_idx): + reasoning = stripped[:end_idx] + if reasoning.strip(): + reasoning_parts.append(reasoning.strip()) + remainder = stripped[end_idx + len(self.end_token) :] + continue + + remainder = stripped + break + + reasoning = "\n".join(reasoning_parts).strip() or None + content = remainder.strip() or None + return reasoning, content + + def _transition_to_content( + self, reasoning: str | None, content: str | None + ) -> DeltaMessage | None: + """Return a delta while suppressing leading post-transition think blocks.""" + content_msg = self._content_delta(content or "") + extra_reasoning = content_msg.reasoning if content_msg else None + final_content = content_msg.content if content_msg else None + reasoning_text = (reasoning or "") + (extra_reasoning or "") + if not reasoning_text and not final_content: + return None + return DeltaMessage( + reasoning=reasoning_text or None, + content=final_content or None, + ) + + def _content_delta(self, delta_text: str) -> DeltaMessage | None: + """Emit content after consuming repeated leading think blocks.""" + if not delta_text and not self._content_buffer: + return None + + if self._content_started: + return DeltaMessage(content=delta_text) if delta_text else None + + self._content_buffer += delta_text + buffer = self._content_buffer.lstrip() + reasoning_parts: list[str] = [] + + while buffer: + if buffer.startswith(self.end_token): + buffer = buffer[len(self.end_token) :].lstrip() + continue + + if buffer.startswith(self.start_token): + after_start = buffer[len(self.start_token) :] + end_idx = after_start.find(self.end_token) + if end_idx == -1: + self._content_buffer = buffer + return None + reasoning = after_start[:end_idx] + if reasoning: + reasoning_parts.append(reasoning) + buffer = after_start[end_idx + len(self.end_token) :].lstrip() + continue + + if self.start_token.startswith(buffer): + self._content_buffer = buffer + return None + + if self.end_token.startswith(buffer): + self._content_buffer = buffer + return None + + self._content_started = True + self._content_buffer = "" + return DeltaMessage( + reasoning="".join(reasoning_parts) or None, + content=buffer, + ) + + self._content_buffer = "" + if reasoning_parts: + return DeltaMessage(reasoning="".join(reasoning_parts)) + return None diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py index f18b238d8..b61a12171 100644 --- a/vllm_mlx/request.py +++ b/vllm_mlx/request.py @@ -9,7 +9,7 @@ import enum import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union if TYPE_CHECKING: from .paged_cache import BlockTable @@ -61,6 
+61,10 @@ class SamplingParams: repetition_penalty: float = 1.0 stop: Optional[List[str]] = None stop_token_ids: Optional[List[int]] = None + # Extra per-request logits processors (e.g. JSON schema constrained + # decoding via ``lm-format-enforcer``). These are merged with any + # built-in processors (repetition/presence penalty) at batch time. + logits_processors: Optional[List[Callable]] = None def __post_init__(self): if self.stop is None: diff --git a/vllm_mlx/rerank.py b/vllm_mlx/rerank.py new file mode 100644 index 000000000..596a18890 --- /dev/null +++ b/vllm_mlx/rerank.py @@ -0,0 +1,384 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reranker engine for cross-encoder models. + +Provides a dedicated RerankEngine with adapter-based scoring for the +OpenAI/Jina-compatible /v1/rerank endpoint. Cross-encoder models use +AutoModelForSequenceClassification-style loading, not mlx_lm.load. +""" + +import logging +import math +import time +from abc import ABC, abstractmethod + +import asyncio + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Adapter contract +# ============================================================================= + + +class RerankAdapter(ABC): + """ + Per-family adapter for reranker models. + + Different cross-encoder families use different tokenization patterns, + score extraction logic, and normalization functions. This contract + isolates those differences so RerankEngine stays family-agnostic. + """ + + @abstractmethod + def tokenize_pair(self, tokenizer, query: str, document: str) -> dict: + """ + Tokenize a (query, document) pair for the cross-encoder. + + Args: + tokenizer: The HuggingFace tokenizer instance. + query: The query string. + document: The document string. + + Returns: + Dict with 'input_ids' and 'attention_mask' as numpy arrays. + """ + ... + + @abstractmethod + def extract_score(self, logits) -> float: + """ + Extract a raw relevance score from model output logits. + + Args: + logits: Model output logits (list or array), shape varies by model. + + Returns: + A single float raw score. + """ + ... + + @abstractmethod + def normalize(self, raw_score: float) -> float: + """ + Normalize a raw score to [0, 1] range. + + Args: + raw_score: The raw score from extract_score(). + + Returns: + Normalized relevance score in [0, 1]. + """ + ... + + +class SigmoidAdapter(RerankAdapter): + """ + Default adapter for single-logit sigmoid rerankers. + + Works with Jina Reranker v2, BGE Reranker v2, and MS-MARCO MiniLM + families. These models output a single relevance logit at position 0, + normalized via sigmoid. + """ + + def tokenize_pair(self, tokenizer, query: str, document: str) -> dict: + """Tokenize as a sentence pair (query, document).""" + return tokenizer( + query, + document, + padding=True, + truncation=True, + max_length=512, + return_tensors="np", + ) + + def extract_score(self, logits) -> float: + """Extract the first logit as the relevance score.""" + return float(logits[0]) + + def normalize(self, raw_score: float) -> float: + """Apply sigmoid normalization.""" + return 1.0 / (1.0 + math.exp(-raw_score)) + + +# ============================================================================= +# Engine +# ============================================================================= + +# Default adapter registry — extend for new model families. 
+_ADAPTER_REGISTRY: dict[str, type[RerankAdapter]] = { + "default": SigmoidAdapter, +} + + +def get_adapter(model_name: str) -> RerankAdapter: + """ + Return the appropriate adapter for a model. + + Falls back to SigmoidAdapter (works for Jina, BGE, MS-MARCO families). + Extend _ADAPTER_REGISTRY for families that need different scoring. + """ + # Future: inspect model config to select adapter automatically. + # For now, all known MLX reranker models use the sigmoid pattern. + return _ADAPTER_REGISTRY["default"]() + + +class RerankEngine: + """ + Reranker engine for cross-encoder sequence classification models. + + Loads cross-encoder models via transformers + MLX (safetensors weights). + Scores (query, document) pairs using the adapter contract for + family-specific tokenization, score extraction, and normalization. + + Supports token-budget batching to avoid OOM on large document lists. + """ + + def __init__( + self, + model_name: str, + token_budget: int = 4096, + max_concurrency: int = 1, + ): + self.model_name = model_name + self.token_budget = token_budget + self.max_concurrency = max_concurrency + self._semaphore = asyncio.Semaphore(max_concurrency) + self._model = None + self._tokenizer = None + self._adapter: RerankAdapter | None = None + + @property + def is_loaded(self) -> bool: + return self._model is not None + + def load(self) -> None: + """ + Load the cross-encoder model and tokenizer. + + Uses transformers AutoTokenizer and loads MLX weights from safetensors + via the model's from_pretrained or equivalent MLX loading path. + """ + from transformers import AutoTokenizer + + logger.info(f"Loading reranker model: {self.model_name}") + start = time.perf_counter() + + self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self._model = self._load_mlx_model(self.model_name) + self._adapter = get_adapter(self.model_name) + + elapsed = time.perf_counter() - start + logger.info(f"Reranker model loaded in {elapsed:.2f}s: {self.model_name}") + + @staticmethod + def _load_mlx_model(model_name: str): + """ + Load an MLX cross-encoder model from HuggingFace Hub. + + Attempts mlx-community weights first (safetensors), then falls back + to transformers AutoModelForSequenceClassification with MLX conversion. + """ + try: + from huggingface_hub import snapshot_download + from safetensors import safe_open + + model_path = snapshot_download(model_name) + + import glob + import json + import os + + # Load model config + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: + config = json.load(f) + + # Load weights from safetensors + weight_files = glob.glob(os.path.join(model_path, "*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + + weights = {} + for wf in weight_files: + with safe_open(wf, framework="numpy") as f: + for key in f.keys(): + weights[key] = mx.array(f.get_tensor(key)) + + # Build model based on architecture + model_type = config.get("model_type", "") + num_labels = config.get("num_labels", 1) + + model = _build_classifier_model(model_type, config, weights, num_labels) + mx.eval(model.parameters()) + return model + + except Exception as e: + logger.error(f"Failed to load reranker model: {e}") + raise + + def _ensure_loaded(self) -> None: + if not self.is_loaded: + self.load() + + def score_pairs(self, query: str, documents: list[str]) -> tuple[list[float], int]: + """ + Score each (query, document) pair and return normalized relevance scores. 
+ + Pairs are batched by token budget to control memory usage. Each batch + is tokenized together and scored in a single forward pass. + Returns (scores, total_tokens) where total_tokens reflects the + actual tokenization used for scoring (consistent with adapter). + + Args: + query: The query string. + documents: List of document strings. + + Returns: + List of normalized relevance scores, one per document, + in the same order as the input documents. + """ + self._ensure_loaded() + + # Tokenize each pair individually to measure token counts + pair_encodings = [] + pair_token_counts = [] + for doc in documents: + enc = self._adapter.tokenize_pair(self._tokenizer, query, doc) + pair_encodings.append(enc) + seq_len = ( + len(enc["input_ids"][0]) + if hasattr(enc["input_ids"][0], "__len__") + else enc["input_ids"].shape[1] + ) + pair_token_counts.append(seq_len) + + # Build batches by token budget + batches = [] + current_batch = [] + current_tokens = 0 + for i, (enc, tok_count) in enumerate(zip(pair_encodings, pair_token_counts)): + if current_batch and current_tokens + tok_count > self.token_budget: + batches.append(current_batch) + current_batch = [] + current_tokens = 0 + current_batch.append((i, enc)) + current_tokens += tok_count + if current_batch: + batches.append(current_batch) + + # Score each batch + all_scores: list[tuple[int, float]] = [] + for batch in batches: + if len(batch) == 1: + # Single pair — use encoding directly + idx, enc = batch[0] + input_ids = mx.array(enc["input_ids"]) + attention_mask = mx.array(enc["attention_mask"]) + else: + # Pad and stack multiple pairs + max_len = max( + ( + len(enc["input_ids"][0]) + if hasattr(enc["input_ids"][0], "__len__") + else enc["input_ids"].shape[1] + ) + for _, enc in batch + ) + padded_ids = [] + padded_mask = [] + for _, enc in batch: + raw_ids = enc["input_ids"][0] + ids = ( + raw_ids.tolist() + if hasattr(raw_ids, "tolist") + else list(raw_ids) + ) + raw_mask = enc["attention_mask"][0] + mask = ( + raw_mask.tolist() + if hasattr(raw_mask, "tolist") + else list(raw_mask) + ) + pad_len = max_len - len(ids) + padded_ids.append(ids + [0] * pad_len) + padded_mask.append(mask + [0] * pad_len) + input_ids = mx.array(padded_ids) + attention_mask = mx.array(padded_mask) + + output = self._model(input_ids, attention_mask=attention_mask) + logits_list = output.logits.tolist() + + for j, (idx, _enc) in enumerate(batch): + logits_row = logits_list[j] if len(batch) > 1 else logits_list[0] + raw_score = self._adapter.extract_score(logits_row) + normalized = self._adapter.normalize(raw_score) + all_scores.append((idx, normalized)) + + # Sort by original index to restore input order + all_scores.sort(key=lambda x: x[0]) + total_tokens = sum(pair_token_counts) + return [score for _, score in all_scores], total_tokens + + +def _build_classifier_model(model_type, config, weights, num_labels): + """ + Build an MLX sequence classification model from config and weights. + + This is a thin wrapper that constructs the appropriate encoder + architecture with a classification head on top. + """ + # Import here to avoid top-level dependency on specific model implementations + return _MLXClassifierWrapper(config, weights, num_labels) + + +class _MLXClassifierWrapper: + """ + Minimal MLX wrapper for sequence classification models. + + Wraps loaded safetensors weights into a callable that returns + logits for (input_ids, attention_mask) pairs. Supports BERT-family + and XLM-RoBERTa-family architectures commonly used as cross-encoders. 
+ """ + + def __init__(self, config: dict, weights: dict, num_labels: int): + self.config = config + self.weights = weights + self.num_labels = num_labels + self._params = list(weights.values()) + + def parameters(self): + """Return model parameters for mx.eval.""" + return self._params + + def __call__(self, input_ids: mx.array, attention_mask: mx.array = None): + """ + Forward pass through the classifier. + + For encoder-only cross-encoders, this runs the full transformer + encoder and classification head. The exact layer wiring depends + on the model architecture. + + This initial implementation uses a weight-lookup forward pass + that works for standard BERT/XLM-RoBERTa classifiers. For + models with non-standard architectures, register a custom + adapter via _ADAPTER_REGISTRY. + """ + # Use the transformers-style weight naming convention + # to walk through embeddings -> encoder layers -> classifier + from vllm_mlx.rerank_forward import classifier_forward + + logits = classifier_forward( + input_ids, attention_mask, self.weights, self.config + ) + return _ClassifierOutput(logits=logits) + + +class _ClassifierOutput: + """Simple container for classifier output logits.""" + + def __init__(self, logits: mx.array): + self.logits = logits diff --git a/vllm_mlx/rerank_forward.py b/vllm_mlx/rerank_forward.py new file mode 100644 index 000000000..9f98cad4e --- /dev/null +++ b/vllm_mlx/rerank_forward.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MLX forward pass for BERT-family sequence classification models. + +Implements a from-weights forward pass for cross-encoder rerankers +that use the standard BERT/XLM-RoBERTa architecture with a +classification head. This avoids pulling in the full transformers +modeling stack at inference time — only the tokenizer is needed +from transformers. +""" + +import mlx.core as mx +import mlx.nn as nn + + +def classifier_forward( + input_ids: mx.array, + attention_mask: mx.array, + weights: dict[str, mx.array], + config: dict, +) -> mx.array: + """ + Run a BERT-family classifier forward pass on MLX. + + Args: + input_ids: (batch, seq_len) token IDs. + attention_mask: (batch, seq_len) attention mask (1=attend, 0=pad). + weights: Dict mapping weight name -> mx.array. + config: Model config dict (from config.json). + + Returns: + logits: (batch, num_labels) classification logits. 
+ """ + hidden_size = config["hidden_size"] + num_heads = config["num_attention_heads"] + num_layers = config["num_hidden_layers"] + num_labels = config.get("num_labels", 1) + eps = config.get("layer_norm_eps", 1e-12) + + head_dim = hidden_size // num_heads + + # Detect weight prefix (bert.* vs roberta.* vs xlm-roberta.*) + prefix = _detect_prefix(weights) + + # --- Embeddings --- + word_emb = weights[f"{prefix}.embeddings.word_embeddings.weight"] + pos_emb = weights[f"{prefix}.embeddings.position_embeddings.weight"] + tok_type_emb = weights[f"{prefix}.embeddings.token_type_embeddings.weight"] + ln_w = weights[f"{prefix}.embeddings.LayerNorm.weight"] + ln_b = weights[f"{prefix}.embeddings.LayerNorm.bias"] + + batch_size, seq_len = input_ids.shape + position_ids = mx.arange(seq_len)[None, :] # (1, seq_len) + token_type_ids = mx.zeros_like(input_ids) + + hidden = word_emb[input_ids] + pos_emb[position_ids] + tok_type_emb[token_type_ids] + hidden = _layer_norm(hidden, ln_w, ln_b, eps) + + # --- Encoder layers --- + # Build causal-free attention mask: (batch, 1, 1, seq_len) + if attention_mask is not None: + ext_mask = attention_mask[:, None, None, :].astype(mx.float32) + ext_mask = (1.0 - ext_mask) * -1e9 + else: + ext_mask = None + + for i in range(num_layers): + lp = f"{prefix}.encoder.layer.{i}" + hidden = _encoder_layer( + hidden, ext_mask, weights, lp, num_heads, head_dim, eps, config + ) + + # --- Pooler (CLS token) --- + cls_hidden = hidden[:, 0, :] # (batch, hidden_size) + pooler_w = weights.get(f"{prefix}.pooler.dense.weight") + pooler_b = weights.get(f"{prefix}.pooler.dense.bias") + if pooler_w is not None: + pooled = mx.tanh(cls_hidden @ pooler_w.T + pooler_b) + else: + pooled = cls_hidden + + # --- Classifier head --- + clf_w = weights["classifier.weight"] + clf_b = weights["classifier.bias"] + logits = pooled @ clf_w.T + clf_b # (batch, num_labels) + + return logits + + +def _detect_prefix(weights: dict) -> str: + """Detect the model weight prefix (bert, roberta, xlm-roberta).""" + for key in weights: + if key.startswith("bert."): + return "bert" + if key.startswith("roberta."): + return "roberta" + if key.startswith("xlm-roberta."): + return "xlm-roberta" + # Default to bert + return "bert" + + +def _layer_norm(x: mx.array, weight: mx.array, bias: mx.array, eps: float) -> mx.array: + """Apply layer normalization.""" + mean = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + return weight * (x - mean) / mx.sqrt(var + eps) + bias + + +def _encoder_layer( + hidden: mx.array, + ext_mask: mx.array | None, + weights: dict, + prefix: str, + num_heads: int, + head_dim: int, + eps: float, + config: dict, +) -> mx.array: + """Run one BERT encoder layer (self-attention + FFN).""" + hidden_size = num_heads * head_dim + + # --- Self-attention --- + q_w = weights[f"{prefix}.attention.self.query.weight"] + q_b = weights[f"{prefix}.attention.self.query.bias"] + k_w = weights[f"{prefix}.attention.self.key.weight"] + k_b = weights[f"{prefix}.attention.self.key.bias"] + v_w = weights[f"{prefix}.attention.self.value.weight"] + v_b = weights[f"{prefix}.attention.self.value.bias"] + + batch_size, seq_len, _ = hidden.shape + + q = ( + (hidden @ q_w.T + q_b) + .reshape(batch_size, seq_len, num_heads, head_dim) + .transpose(0, 2, 1, 3) + ) + k = ( + (hidden @ k_w.T + k_b) + .reshape(batch_size, seq_len, num_heads, head_dim) + .transpose(0, 2, 1, 3) + ) + v = ( + (hidden @ v_w.T + v_b) + .reshape(batch_size, seq_len, num_heads, head_dim) + .transpose(0, 2, 1, 3) + ) + + scale 
= head_dim**-0.5 + attn_scores = (q @ k.transpose(0, 1, 3, 2)) * scale # (batch, heads, seq, seq) + + if ext_mask is not None: + attn_scores = attn_scores + ext_mask + + attn_probs = mx.softmax(attn_scores, axis=-1) + attn_out = ( + (attn_probs @ v).transpose(0, 2, 1, 3).reshape(batch_size, seq_len, hidden_size) + ) + + # Attention output projection + residual + LayerNorm + ao_w = weights[f"{prefix}.attention.output.dense.weight"] + ao_b = weights[f"{prefix}.attention.output.dense.bias"] + ao_ln_w = weights[f"{prefix}.attention.output.LayerNorm.weight"] + ao_ln_b = weights[f"{prefix}.attention.output.LayerNorm.bias"] + + attn_out = attn_out @ ao_w.T + ao_b + hidden = _layer_norm(hidden + attn_out, ao_ln_w, ao_ln_b, eps) + + # --- FFN --- + inter_w = weights[f"{prefix}.intermediate.dense.weight"] + inter_b = weights[f"{prefix}.intermediate.dense.bias"] + out_w = weights[f"{prefix}.output.dense.weight"] + out_b = weights[f"{prefix}.output.dense.bias"] + out_ln_w = weights[f"{prefix}.output.LayerNorm.weight"] + out_ln_b = weights[f"{prefix}.output.LayerNorm.bias"] + + intermediate = hidden @ inter_w.T + inter_b + intermediate = _gelu(intermediate) + ffn_out = intermediate @ out_w.T + out_b + hidden = _layer_norm(hidden + ffn_out, out_ln_w, out_ln_b, eps) + + return hidden + + +def _gelu(x: mx.array) -> mx.array: + """GELU activation (exact form).""" + return nn.gelu(x) diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 0158140ee..10256934f 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -24,6 +24,7 @@ from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig from .paged_cache import PagedCacheManager +from .ssd_cache import SSDCacheConfig, SSDCacheTier from .prefix_cache import BlockAwarePrefixCache, PrefixCacheManager from .request import Request, RequestOutput, RequestStatus, SamplingParams from .utils.mamba_cache import ensure_mamba_support @@ -80,6 +81,9 @@ class SchedulerConfig: kv_cache_quantization_group_size: int = 64 kv_cache_min_quantize_tokens: int = 256 + # TurboQuant KV cache compression (4.6x at 3-bit, replaces standard quantization) + turbo_kv_bits: Optional[int] = None # 1-4 bit; None = disabled + # Paged cache settings (experimental - for memory efficiency) use_paged_cache: bool = ( False # Use BlockAwarePrefixCache instead of PrefixCacheManager @@ -98,6 +102,10 @@ class SchedulerConfig: # 0 = disabled. Only effective when chunked_prefill_tokens > 0. mid_prefill_save_interval: int = 8192 + # SSD cache tiering + ssd_cache_dir: Optional[str] = None # None = disabled + ssd_cache_max_gb: float = 10.0 + # MTP (Multi-Token Prediction) settings # Uses the model's built-in MTP head to predict multiple tokens per step enable_mtp: bool = False @@ -129,6 +137,33 @@ class SchedulerOutput: has_work: bool = False +def _install_prompt_cache_save(batch_gen: "BatchGenerator", prompt_cache_save) -> None: + """Monkey-patch ``_process_prompts`` to capture prompt-only cache state. + + Can be installed independently of chunked prefill. If chunked prefill is + also installed, *it* takes over ``_process_prompts`` and invokes the + callback itself, so call this **before** ``_install_chunked_prefill``. 
+ """ + _orig_process_prompts = batch_gen._process_prompts + + try: + from mlx_lm.generate import Batch as _batch_cls + except ImportError: + _batch_cls = None # extract_cache fallback handled in patched fn + + def _patched_process_prompts(prompts, _self=batch_gen): + batch = _orig_process_prompts(prompts) + for e, uid in enumerate(batch.uids): + if batch.num_tokens[e] == 0: + try: + prompt_cache_save(uid, batch.extract_cache(e)) + except Exception: + pass + return batch + + batch_gen._process_prompts = _patched_process_prompts + + def _install_chunked_prefill( batch_gen: "BatchGenerator", budget: int, @@ -1059,9 +1094,16 @@ def _mtp_next(self=batch_gen): batch_gen._step = _mtp_step batch_gen._next = _mtp_next + if num_draft_tokens != 1: + logger.warning( + "[MTP] num_draft_tokens=%d requested, but the current batched MTP " + "path drafts exactly one token per verify step", + num_draft_tokens, + ) mode_str = "optimistic (no verify)" if optimistic else "always-advance" logger.info( - f"[MTP] installed with num_draft_tokens={num_draft_tokens}, " f"{mode_str} mode" + f"[MTP] installed with num_draft_tokens={num_draft_tokens}, " + f"effective_draft_tokens=1, {mode_str} mode" ) @@ -1147,6 +1189,7 @@ def __init__( kv_bits=self.config.kv_cache_quantization_bits, kv_group_size=self.config.kv_cache_quantization_group_size, kv_min_quantize_tokens=self.config.kv_cache_min_quantize_tokens, + turbo_kv_bits=self.config.turbo_kv_bits, ) self.memory_aware_cache = MemoryAwarePrefixCache( model=model, @@ -1156,6 +1199,22 @@ def __init__( f"Memory-aware cache enabled: " f"limit={self.memory_aware_cache.memory_limit_mb:.1f}MB" ) + + # Attach SSD tier if configured + self._ssd_tier: Optional[SSDCacheTier] = None + if self.config.ssd_cache_dir is not None: + ssd_config = SSDCacheConfig( + cache_dir=self.config.ssd_cache_dir, + max_size_gb=self.config.ssd_cache_max_gb, + ) + self._ssd_tier = SSDCacheTier(ssd_config) + self._ssd_tier.start_writer() + self._ssd_tier.reconcile() + self.memory_aware_cache.set_ssd_tier(self._ssd_tier) + logger.info( + f"SSD cache tier enabled: dir={self.config.ssd_cache_dir}, " + f"max={self.config.ssd_cache_max_gb}GB" + ) else: # Use legacy entry-count based prefix cache self.prefix_cache = PrefixCacheManager( @@ -1271,11 +1330,14 @@ def _prefill_progress(progress_list): # monkey-patch. Not a BatchGenerator constructor parameter. bg.prompt_progress_callback = _prefill_progress - # Install chunked prefill when explicitly configured OR when - # memory-aware cache is active (needed for prefix_boundary saves - # in agentic multi-turn workloads with hybrid Mamba+Transformer models). + # Install chunked prefill only when explicitly configured. + # memory_aware_cache fetch/store works independently; the mid-prefill + # save callback is an optimisation, not a requirement. + # When chunked_prefill_tokens == 0 (the default), honour the user's + # intent — do NOT silently re-enable chunked prefill just because + # memory_aware_cache is active (see #178). chunked_budget = self.config.chunked_prefill_tokens - need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None + need_chunked = chunked_budget > 0 # The chunked prefill monkey-patch relies on BatchGenerator internals # (_process_prompts, active_batch, _step, etc.) 
that were refactored @@ -1284,20 +1346,19 @@ def _prefill_progress(progress_list): bg, "active_batch" ) + prompt_cache_cb = None + if self.memory_aware_cache is not None: + prompt_cache_cb = self._make_prompt_cache_save_callback() + if need_chunked and chunked_compatible: - if chunked_budget <= 0: - # No explicit budget — use a very large value so normal - # prompts pass through unchanged. Prefix boundary splits - # still trigger via _needs_boundary_split. - chunked_budget = 999_999 + # Full chunked prefill with mid-prefill saves and prompt cache + # save wired through the chunked next() and _process_prompts + # monkey-patches inside _install_chunked_prefill. mid_prefill_cb = None save_interval = self.config.mid_prefill_save_interval if save_interval > 0 and self.memory_aware_cache is not None: mid_prefill_cb = self._make_mid_prefill_save_callback(save_interval) logger.info(f"[mid_prefill_cache] enabled, interval={save_interval}") - prompt_cache_cb = None - if self.memory_aware_cache is not None: - prompt_cache_cb = self._make_prompt_cache_save_callback() _install_chunked_prefill( bg, chunked_budget, @@ -1314,6 +1375,14 @@ def _prefill_progress(progress_list): "check compatibility." ) + # When chunked prefill is off but memory_aware_cache is active, + # install the lightweight _process_prompts hook so prompt-only + # cache entries are still captured. This is the only safe capture + # point for hybrid Mamba+Transformer models (#178). + if not need_chunked and prompt_cache_cb is not None: + if hasattr(bg, "_process_prompts"): + _install_prompt_cache_save(bg, prompt_cache_cb) + # Install MTP if the model supports it if self.config.enable_mtp: if hasattr(self.model, "mtp") and self.model.mtp is not None: @@ -1753,6 +1822,14 @@ def add_request(self, request: Request) -> None: f"prompt_tokens={len(request.prompt_token_ids)} " f"time={_fetch_dt:.3f}s entries={len(self.memory_aware_cache._entries)}" ) + # Check SSD tier for cold-tier hit + if hasattr(self, "_ssd_tier") and self._ssd_tier is not None: + ssd_candidate = self.memory_aware_cache.check_ssd( + request.prompt_token_ids + ) + if ssd_candidate is not None: + request.cache_hit_type = "ssd_pending" + request._ssd_candidate = ssd_candidate elif self.prefix_cache is not None: # Use legacy prefix cache cache, remaining = self.prefix_cache.fetch_cache(request.prompt_token_ids) @@ -1881,6 +1958,12 @@ def _schedule_waiting(self) -> List[Request]: Returns: List of requests that were scheduled """ + # Attempt synchronous SSD promotion for any ssd_pending requests + # before scheduling. This keeps SSD I/O out of fetch() while + # avoiding engine modifications. + if hasattr(self, "_ssd_tier") and self._ssd_tier is not None: + self._try_promote_ssd_pending() + scheduled = [] while self.waiting and len(self.running) < self.config.max_num_seqs: @@ -1922,15 +2005,27 @@ def _schedule_waiting(self) -> List[Request]: request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids - # Build per-request logits_processors from repetition_penalty + # Build per-request logits_processors from repetition_penalty and + # any caller-supplied extras (e.g. JSON schema constrained + # decoding). 
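+                # Each processor is a callable applied to the logits before
+                # sampling; caller-supplied extras typically arrive from the
+                # server layer via SamplingParams.logits_processors (see
+                # request.py above).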
rep_penalty = request.sampling_params.repetition_penalty - lp = None + extra_lp = request.sampling_params.logits_processors or [] + combined_lp: list = [] if rep_penalty and rep_penalty != 1.0: - lp = make_logits_processors(repetition_penalty=rep_penalty) + combined_lp.extend( + make_logits_processors(repetition_penalty=rep_penalty) + ) logger.info( f"[rep_penalty] request={request.request_id[:12]} " - f"penalty={rep_penalty} processors={len(lp)}" + f"penalty={rep_penalty}" + ) + if extra_lp: + combined_lp.extend(extra_lp) + logger.info( + f"[logits_proc] request={request.request_id[:12]} " + f"extra_processors={len(extra_lp)}" ) + lp = combined_lp # Insert into BatchGenerator with optional cache. # Wrap in try/except: if cache shapes are incompatible @@ -1939,9 +2034,10 @@ def _schedule_waiting(self) -> List[Request]: insert_kwargs = { "max_tokens": [request.sampling_params.max_tokens], "caches": [cache_to_use] if cache_to_use else None, + # Always pass logits_processors (even empty list) so that + # mlx_lm BatchGenerator never stores None per-sequence. + "logits_processors": [lp] if lp else [[]], } - if lp: - insert_kwargs["logits_processors"] = [lp] try: uids = self.batch_generator.insert( [tokens_to_process], @@ -2259,6 +2355,11 @@ def _is_cache_corruption_error(self, error: Exception) -> bool: error_str = str(error) return any(pattern in error_str for pattern in CACHE_CORRUPTION_PATTERNS) + def _is_stream_thread_error(self, error: Exception) -> bool: + """Check if an error indicates MLX stream/thread ownership mismatch.""" + error_str = str(error) + return "no Stream(" in error_str or "no Stream(gpu" in error_str + def _recover_from_cache_error(self) -> None: """Recover from cache corruption error.""" # Properly close batch generator (this is the source of the corruption) @@ -2407,6 +2508,8 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: else: raise except Exception as e: + if self._is_stream_thread_error(e): + raise import traceback logger.error( @@ -2583,6 +2686,24 @@ def get_cache_stats(self) -> Optional[Dict[str, Any]]: return self.prefix_cache.get_stats() return None + def clear_runtime_caches(self) -> Dict[str, bool]: + """Clear prefix-cache state without resetting scheduler/request state.""" + cleared = { + "paged_cache": False, + "memory_aware_cache": False, + "prefix_cache": False, + } + if self.block_aware_cache is not None: + self.block_aware_cache.clear() + cleared["paged_cache"] = True + if self.memory_aware_cache is not None: + self.memory_aware_cache.clear() + cleared["memory_aware_cache"] = True + if self.prefix_cache is not None: + self.prefix_cache.clear() + cleared["prefix_cache"] = True + return cleared + def reset(self) -> None: """Reset the scheduler state.""" # Drain any pending deferred aborts @@ -2603,12 +2724,10 @@ def reset(self) -> None: self._current_sampler_params = None # Clear caches - if self.block_aware_cache is not None: - self.block_aware_cache.clear() - if self.memory_aware_cache is not None: - self.memory_aware_cache.clear() - if self.prefix_cache is not None: - self.prefix_cache.clear() + self.clear_runtime_caches() + + # Close SSD tier on reset + self.close_ssd_tier() def deep_reset(self) -> None: """ @@ -2657,3 +2776,216 @@ def load_cache_from_disk(self, cache_dir: str) -> int: return self.memory_aware_cache.load_from_disk(cache_dir) logger.info("[cache_persist] no memory-aware cache to load into") return 0 + + def clear_prefix_cache(self) -> None: + """Clear the in-memory prefix cache (keeps disk cache untouched).""" + if 
self.memory_aware_cache is not None and hasattr( + self.memory_aware_cache, "clear" + ): + self.memory_aware_cache.clear() + logger.info("[clear_prefix_cache] memory-aware cache cleared") + return + if self.prefix_cache is not None and hasattr(self.prefix_cache, "clear"): + self.prefix_cache.clear() + logger.info("[clear_prefix_cache] prefix cache cleared") + + def close_ssd_tier(self) -> None: + """Shut down the SSD cache tier if present.""" + if hasattr(self, "_ssd_tier") and self._ssd_tier is not None: + self._ssd_tier.close() + self._ssd_tier = None + logger.info("SSD cache tier closed") + + def _try_promote_ssd_pending(self) -> None: + """Attempt synchronous SSD promotion for waiting requests tagged ssd_pending. + + Called from _schedule_waiting() before requests are moved to running. + Reads SSD entries synchronously (disk I/O stays out of fetch() per spec). + """ + for request in self.waiting: + if getattr(request, "cache_hit_type", None) != "ssd_pending": + continue + + candidate = getattr(request, "_ssd_candidate", None) + if candidate is None: + continue + + memory_bytes = candidate["memory_bytes"] + + # Check RAM budget availability + if self.memory_aware_cache is None: + request.cache_hit_type = "miss" + continue + + if ( + self.memory_aware_cache._current_memory + memory_bytes + > self.memory_aware_cache._max_memory + ): + self._ssd_tier._stats.promotion_failures += 1 + request.cache_hit_type = "miss" + logger.info( + f"[ssd_promote] request={request.request_id[:12]} " + f"budget denied ({memory_bytes} bytes)" + ) + continue + + # Tentatively reserve budget + self.memory_aware_cache._current_memory += memory_bytes + + # Use the SSD entry's actual token count for read and store, + # NOT the full prompt tokens. For prefix hits these differ. + matched_count = candidate["matched_tokens"] + matched_tokens = tuple(request.prompt_token_ids[:matched_count]) + + try: + cache_layers = self._ssd_tier._read_entry( + matched_tokens, candidate["file_path"] + ) + except Exception: + self.memory_aware_cache._current_memory -= memory_bytes + self._ssd_tier._stats.promotion_failures += 1 + request.cache_hit_type = "miss" + logger.exception( + f"[ssd_promote] request={request.request_id[:12]} " + f"disk read failed" + ) + continue + + if cache_layers is None: + self.memory_aware_cache._current_memory -= memory_bytes + self._ssd_tier._stats.promotion_failures += 1 + request.cache_hit_type = "miss" + continue + + # Release tentative budget (store() will account properly) + self.memory_aware_cache._current_memory -= memory_bytes + + # Reconstruct and store under the matched prefix tokens + reconstructed = self._reconstruct_ssd_layers(cache_layers) + if reconstructed is None: + request.cache_hit_type = "miss" + continue + + self.memory_aware_cache.store( + list(matched_tokens), reconstructed, evict_prefixes=False + ) + + request.prompt_cache = reconstructed + request.cached_tokens = matched_count + request.remaining_tokens = request.prompt_token_ids[matched_count:] + request.cache_hit_type = "ssd_hit" + + self._ssd_tier._stats.ssd_hits += 1 + self._ssd_tier._index.touch(matched_tokens) + + logger.info( + f"[ssd_promote] request={request.request_id[:12]} " + f"{candidate['match_type']} promote: {matched_count}/{len(request.prompt_token_ids)} tokens from SSD, " + f"{len(request.remaining_tokens)} remaining" + ) + + async def promote_from_ssd(self, request) -> bool: + """Promote a cold-tier cache entry for a request (async version). 
+ + Alternative to _try_promote_ssd_pending() for callers with an + async event loop. Uses asyncio.to_thread for non-blocking disk I/O. + + Returns True if promotion succeeded and request was updated. + """ + if not hasattr(self, "_ssd_tier") or self._ssd_tier is None: + return False + + candidate = getattr(request, "_ssd_candidate", None) + if candidate is None: + return False + + def reserve_budget(nbytes: int) -> bool: + """Tentatively reserve RAM budget for promotion.""" + if self.memory_aware_cache is None: + return False + if ( + self.memory_aware_cache._current_memory + nbytes + > self.memory_aware_cache._max_memory + ): + return False + self.memory_aware_cache._current_memory += nbytes + return True + + def release_budget(nbytes: int) -> None: + """Release tentatively reserved budget on failure.""" + if self.memory_aware_cache is not None: + self.memory_aware_cache._current_memory -= nbytes + + # Use matched token count, not full prompt, for prefix hits + matched_count = candidate.get("matched_tokens", len(request.prompt_token_ids)) + matched_tokens = tuple(request.prompt_token_ids[:matched_count]) + + cache_layers = await self._ssd_tier.async_promote( + matched_tokens, reserve_budget, release_budget + ) + + if cache_layers is None: + request.cache_hit_type = "miss" + return False + + # Release tentative budget — store() will account properly + release_budget(candidate["memory_bytes"]) + + # Reconstruct cache objects from deserialized layer dicts + reconstructed = self._reconstruct_ssd_layers(cache_layers) + if reconstructed is None: + request.cache_hit_type = "miss" + return False + + # Store in RAM cache under the matched prefix tokens + self.memory_aware_cache.store( + list(matched_tokens), reconstructed, evict_prefixes=False + ) + + request.prompt_cache = reconstructed + request.cached_tokens = matched_count + request.remaining_tokens = request.prompt_token_ids[matched_count:] + request.cache_hit_type = "ssd_hit" + + logger.info( + f"[ssd_promote] request={request.request_id[:12]} " + f"{candidate.get('match_type', 'exact')} promote: " + f"{matched_count}/{len(request.prompt_token_ids)} tokens from SSD, " + f"{len(request.remaining_tokens)} remaining" + ) + return True + + def _reconstruct_ssd_layers(self, layer_dicts: list[dict]) -> list | None: + """Reconstruct cache objects from deserialized layer dicts. + + Converts numpy arrays back to MLX arrays and creates KVCache objects. 
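+
+        Expected layer dict keys (illustrative): attention layers carry
+        "keys", "values" and "offset" (optionally "max_size", "keep", "step",
+        "_idx"); recurrent/Mamba-style layers carry a "state" list instead.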
+ """ + try: + from mlx_lm.models.cache import KVCache + + result = [] + for ld in layer_dicts: + if "keys" in ld and "values" in ld: + kv = KVCache() + kv.keys = mx.array(ld["keys"]) + kv.values = mx.array(ld["values"]) + kv.offset = ld["offset"] + for attr in ("max_size", "keep", "step", "_idx"): + if attr in ld: + setattr(kv, attr, ld[attr]) + result.append(kv) + elif "state" in ld: + # ArraysCache — need to wrap in a compatible object + state_arrays = [mx.array(a) for a in ld["state"]] + # Create a simple namespace-like object + layer_obj = type("ArraysCacheLayer", (), {"state": state_arrays})() + result.append(layer_obj) + else: + logger.warning( + f"[ssd_promote] unknown layer dict format: {list(ld.keys())}" + ) + return None + return result + except Exception as e: + logger.warning(f"[ssd_promote] reconstruction failed: {e}") + return None diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 6cd9581bf..f9b875a53 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -39,22 +39,26 @@ import argparse import asyncio +import copy +from dataclasses import dataclass import json import logging import os import re import secrets -import tempfile +import socket as _socket import threading import time import uuid -from collections import defaultdict +from collections import OrderedDict, defaultdict from collections.abc import AsyncIterator +from contextlib import suppress import uvicorn from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile from fastapi.responses import Response, StreamingResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel from starlette.routing import Match # Import from new modular API @@ -93,11 +97,42 @@ Message, # noqa: F401 ModelInfo, # noqa: F401 ModelsResponse, + RerankRequest, + RerankResponse, + RerankResult, + RerankUsage, ToolCall, Usage, # noqa: F401 VideoUrl, # noqa: F401 ) +from .api.responses_models import ( + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallItem, + ResponseFunctionCallOutputItem, + ResponseFunctionTool, + ResponseIncompleteDetails, + ResponseInProgressEvent, + ResponseMessageItem, + ResponseObject, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputTextDeltaEvent, + ResponseOutputTextDoneEvent, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseReasoningTextPart, + ResponseTextContentPart, + ResponsesRequest, + ResponsesUsage, +) from .api.tool_calling import ( + StreamingJsonFenceStripper, + build_json_logits_processor, build_json_system_prompt, convert_tools_for_template, parse_json_output, @@ -109,24 +144,53 @@ extract_multimodal_content, is_mllm_model, # noqa: F401 ) +from .audio_limits import ( + DEFAULT_MAX_AUDIO_UPLOAD_BYTES, + DEFAULT_MAX_AUDIO_UPLOAD_MB, + DEFAULT_MAX_TTS_INPUT_CHARS, + save_upload_with_limit, + validate_tts_input_length, +) +from .cli_arg_types import make_json_object_arg_parser from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine +from .endpoint_model_policies import ( + resolve_embedding_model_name, + resolve_stt_model_name, + resolve_tts_model_name, +) +from .engine.base import suspend_cancellation +from .lifecycle import ModelSpec, ResidencyManager from .metrics import metrics as _metrics -from .tool_parsers import ToolParserManager +from .tool_parsers import ToolParserManager, 
get_parser_stop_tokens logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +_IMPORTED_SIMPLE_ENGINE = SimpleEngine + # Global engine instance _engine: BaseEngine | None = None _model_name: str | None = None _model_path: str | None = ( None # Actual model path (for cache dir, not affected by --served-model-name) ) +_warm_prompts_path: str | None = None # Path to JSON of prompts to pre-warm at startup +_default_model_key: str | None = None _default_max_tokens: int = 32768 +_max_request_tokens: int = 32768 _default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes) _default_temperature: float | None = None # Set via --default-temperature _default_top_p: float | None = None # Set via --default-top-p +_default_chat_template_kwargs: dict[str, object] | None = None _metrics_enabled = False +_max_audio_upload_bytes: int = DEFAULT_MAX_AUDIO_UPLOAD_BYTES +_max_tts_input_chars: int = DEFAULT_MAX_TTS_INPUT_CHARS +_force_mllm_model: bool = False +_auto_unload_idle_seconds: float = 0.0 +_lazy_load_model: bool = False +_residency_manager: ResidencyManager | None = None +_lifecycle_task: asyncio.Task | None = None +_lifespan_active: bool = False _FALLBACK_TEMPERATURE = 0.7 _FALLBACK_TOP_P = 0.9 @@ -150,6 +214,292 @@ def _resolve_top_p(request_value: float | None) -> float: return _FALLBACK_TOP_P +def _resolve_request_max_tokens(requested_value: int | None) -> int: + """Resolve and validate a request's max_tokens budget.""" + if requested_value is None: + return _default_max_tokens + if requested_value > _max_request_tokens: + raise HTTPException( + status_code=400, + detail=f"max_tokens exceeds server limit ({_max_request_tokens})", + ) + return requested_value + + +def _resolve_chat_template_kwargs( + request_value: dict[str, object] | None, +) -> dict[str, object]: + """Resolve chat template kwargs: request > server default > empty dict.""" + resolved: dict[str, object] = {} + if _default_chat_template_kwargs: + resolved.update(_default_chat_template_kwargs) + if request_value: + resolved.update(request_value) + return resolved + + +@dataclass +class PreparedChatInvocation: + """Fully prepared inputs for a single engine.chat/stream_chat call.""" + + messages: list[dict] + chat_kwargs: dict[str, object] + response_format: object | None + json_logits_processor: object | None + + +def _prepare_chat_messages( + engine: BaseEngine, + request_messages: list[Message | dict], +) -> tuple[list[dict], list, list, bool]: + """Normalize messages and collect media once for both stream/non-stream paths.""" + is_mllm = bool(getattr(engine, "is_mllm", False)) + preserve_native = bool(getattr(engine, "preserve_native_tool_format", False)) + + if is_mllm: + # For MLLM models, keep original messages with embedded images + # (MLLM.chat() extracts images from message content internally) + messages = [] + for msg in request_messages: + if hasattr(msg, "model_dump"): + msg_dict = msg.model_dump(exclude_none=True) + else: + raw = dict(msg) + msg_dict = {k: v for k, v in raw.items() if v is not None} + messages.append(msg_dict) + images, videos = [], [] # MLLM extracts these from messages + logger.debug(f"MLLM: Processing {len(messages)} messages") + # Convert tool_call arguments from JSON string to dict so that + # chat templates can iterate them (e.g. GLM-4.6V calls .items()). + # The LLM path does this inside extract_multimodal_content(), but + # the MLLM path bypasses that function. 
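+        # e.g. {"arguments": '{"city": "Paris"}'} becomes
+        # {"arguments": {"city": "Paris"}} (values are illustrative).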
+ if preserve_native: + for msg_dict in messages: + for tc in msg_dict.get("tool_calls") or []: + func = tc.get("function") or {} + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + pass + messages = _normalize_messages(messages) + else: + # For LLM, extract text, images, and videos separately + messages, images, videos = extract_multimodal_content( + request_messages, + preserve_native_format=preserve_native, + ) + messages = _normalize_messages(messages) + + has_media = bool(images or videos) + if is_mllm and not has_media: + # MLLM extracts media from messages directly, so images/videos are + # always empty. Check message content for video/image types instead. + for msg in request_messages: + content = msg.content if hasattr(msg, "content") else msg.get("content", "") + if isinstance(content, list): + for item in content: + item_type = ( + item.type + if hasattr(item, "type") + else (item.get("type", "") if isinstance(item, dict) else "") + ) + if item_type in ("image_url", "image", "video", "video_url"): + has_media = True + break + if has_media: + break + + return messages, images, videos, has_media + + +def _prepare_json_logits_processor( + engine: BaseEngine, + messages: list[dict], + response_format: object | None, + *, + tools: list | None, + tool_choice: object | None, + log_context: str | None = None, +) -> tuple[list[dict], object | None]: + """Inject response_format instruction and build constrained decoding processor.""" + json_logits_processor = None + if not response_format: + return messages, json_logits_processor + + json_instruction = build_json_system_prompt(response_format) + if json_instruction: + messages = _inject_json_instruction(messages, json_instruction) + + # ``tools`` + ``response_format`` is undefined in OpenAI; skip constraints + # when tools are active so tool-call markup can still be emitted. 
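+    # (``tool_choice == "none"`` disables tools entirely, so constrained
+    # decoding can still be applied in that case.)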
+ if tools and tool_choice != "none": + return messages, json_logits_processor + + tokenizer_obj = _get_engine_tokenizer(engine) + if tokenizer_obj is None: + return messages, json_logits_processor + + try: + json_logits_processor = build_json_logits_processor( + response_format, tokenizer_obj + ) + except Exception as exc: + logger.warning("Failed to build JSON logits processor: %s", exc) + json_logits_processor = None + + if json_logits_processor is not None: + log_label = f" for {log_context}" if log_context else "" + logger.info( + "Constrained decoding enabled%s response_format.type=%s", + log_label, + ( + getattr(response_format, "type", None) + if not isinstance(response_format, dict) + else response_format.get("type") + ), + ) + + return messages, json_logits_processor + + +def _prepare_chat_completion_invocation( + engine: BaseEngine, + request: ChatCompletionRequest, + effective_max_tokens: int, +) -> PreparedChatInvocation: + """Precompute messages, kwargs, and decoding constraints for chat completions.""" + messages, images, videos, has_media = _prepare_chat_messages( + engine, request.messages + ) + response_format = request.response_format + messages, json_logits_processor = _prepare_json_logits_processor( + engine, + messages, + response_format, + tools=request.tools, + tool_choice=request.tool_choice, + ) + + rep_penalty = request.repetition_penalty + chat_kwargs = { + "max_tokens": effective_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "repetition_penalty": request.repetition_penalty or 1.0, + } + if rep_penalty is not None: + chat_kwargs["repetition_penalty"] = rep_penalty + + if has_media: + chat_kwargs["images"] = images if images else None + chat_kwargs["videos"] = videos if videos else None + video_fps = getattr(request, "video_fps", None) + if video_fps: + chat_kwargs["video_fps"] = video_fps + video_max_frames = getattr(request, "video_max_frames", None) + if video_max_frames: + chat_kwargs["video_max_frames"] = video_max_frames + + if request.specprefill is not None: + chat_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + resolved_chat_template_kwargs = _resolve_chat_template_kwargs( + request.chat_template_kwargs + ) + if resolved_chat_template_kwargs: + chat_kwargs["chat_template_kwargs"] = resolved_chat_template_kwargs + + if request.enable_thinking is not None: + chat_kwargs["enable_thinking"] = request.enable_thinking + + if request.tools and request.tool_choice != "none": + template_tools = convert_tools_for_template(request.tools) + template_tools, messages = _apply_forced_tool_choice( + request.tool_choice, template_tools, messages, chat_kwargs + ) + chat_kwargs["tools"] = template_tools + + parser_name = _tool_call_parser if _enable_auto_tool_choice else None + merged_stop = get_parser_stop_tokens(parser_name, request.stop) + if merged_stop: + chat_kwargs["stop"] = merged_stop + + if json_logits_processor is not None: + existing = chat_kwargs.get("logits_processors") or [] + chat_kwargs["logits_processors"] = list(existing) + [json_logits_processor] + # Constrained decoding is incompatible with reasoning parsers: + # suppress implicit thinking if caller didn't explicitly set it. 
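+        # A schema-constrained grammar cannot emit reasoning markup such as
+        # <think> spans, so leaving implicit thinking enabled would corrupt
+        # the structured output (hence the default-off below).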
+ if request.enable_thinking is None: + request.enable_thinking = False + chat_kwargs["enable_thinking"] = False + + return PreparedChatInvocation( + messages=messages, + chat_kwargs=chat_kwargs, + response_format=response_format, + json_logits_processor=json_logits_processor, + ) + + +def _prepare_anthropic_invocation( + engine: BaseEngine, + openai_request: ChatCompletionRequest, + effective_max_tokens: int, +) -> PreparedChatInvocation: + """Precompute messages, kwargs, and decoding constraints for Anthropic API.""" + messages, _, _, _ = _prepare_chat_messages(engine, openai_request.messages) + response_format = openai_request.response_format + messages, json_logits_processor = _prepare_json_logits_processor( + engine, + messages, + response_format, + tools=openai_request.tools, + tool_choice=openai_request.tool_choice, + log_context="Anthropic", + ) + + chat_kwargs = { + "max_tokens": effective_max_tokens, + "temperature": openai_request.temperature, + "top_p": openai_request.top_p, + "top_k": openai_request.top_k or 0, + "min_p": openai_request.min_p or 0.0, + "presence_penalty": openai_request.presence_penalty or 0.0, + "repetition_penalty": openai_request.repetition_penalty or 1.0, + } + resolved_chat_template_kwargs = _resolve_chat_template_kwargs( + openai_request.chat_template_kwargs + ) + if resolved_chat_template_kwargs: + chat_kwargs["chat_template_kwargs"] = resolved_chat_template_kwargs + + if openai_request.tools and openai_request.tool_choice != "none": + template_tools = convert_tools_for_template(openai_request.tools) + template_tools, messages = _apply_forced_tool_choice( + openai_request.tool_choice, template_tools, messages, chat_kwargs + ) + chat_kwargs["tools"] = template_tools + + if json_logits_processor is not None: + existing = chat_kwargs.get("logits_processors") or [] + chat_kwargs["logits_processors"] = list(existing) + [json_logits_processor] + # Suppress thinking: constrained decoding prevents tags. + chat_kwargs["enable_thinking"] = False + + return PreparedChatInvocation( + messages=messages, + chat_kwargs=chat_kwargs, + response_format=response_format, + json_logits_processor=json_logits_processor, + ) + + # Global MCP manager _mcp_manager = None _mcp_executor = None @@ -158,6 +508,10 @@ def _resolve_top_p(request_value: float | None) -> float: _embedding_engine = None _embedding_model_locked: str | None = None # Set when --embedding-model is used +# Global reranker engine (lazy loaded) +_rerank_engine = None +_rerank_model_locked: str | None = None # Set when --rerank-model is used + # API key authentication _api_key: str | None = None _auth_warning_logged: bool = False @@ -169,39 +523,135 @@ def _resolve_top_p(request_value: float | None) -> float: _enable_auto_tool_choice: bool = False _tool_call_parser: str | None = None # Parser name: auto, mistral, qwen, llama, hermes _tool_parser_instance = None # Instantiated parser +_responses_store: OrderedDict[str, dict] = OrderedDict() +_RESPONSES_STORE_MAX_SIZE: int = 1000 # Pattern to strip leaked tool call markup from content output. # Safety net: the tool parser should consume these, but if it doesn't # (e.g. malformed JSON, stray closing tags), strip them before emitting. 
_TOOL_MARKUP_PATTERN = re.compile(r"|") +_STREAMING_TOOL_MARKERS = ( + "", + "<|tool_call>", + "", + ' str: + """Escape control characters before logging untrusted text.""" + text = str(value) + escaped: list[str] = [] + for ch in text: + if ch == "\n": + escaped.append("\\n") + elif ch == "\r": + escaped.append("\\r") + elif ch == "\t": + escaped.append("\\t") + elif ch.isprintable(): + escaped.append(ch) + else: + code = ord(ch) + if code <= 0xFF: + escaped.append(f"\\x{code:02x}") + else: + escaped.append(f"\\u{code:04x}") + sanitized = "".join(escaped) + if limit is not None and len(sanitized) > limit: + return sanitized[:limit] + "..." + return sanitized + + +def _log_and_raise_internal_error(log_prefix: str, exc: Exception, detail: str) -> None: + """Log a sanitized exception string and raise a generic 500 response.""" + logger.error("%s: %s", log_prefix, _sanitize_log_text(exc, limit=500)) + raise HTTPException(status_code=500, detail=detail) + + +# Lifecycle startup coordination — an Event lets the lifecycle loop block +# efficiently instead of polling with short sleeps. Created lazily so +# it binds to the correct event loop at runtime rather than import time. +# +# Important: the Event is bound to the loop that was running when +# _get_idle_unload_event() is first called. In production this is always +# the single uvicorn event loop, but test fixtures must reset this to None +# between tests to avoid cross-loop contamination when pytest creates +# fresh loops per test. +_idle_unload_enabled: asyncio.Event | None = None + + +def _get_idle_unload_event() -> asyncio.Event: + """Return the idle-unload gate event, creating it on first use. + + The returned Event is bound to the running loop at creation time. + Reset ``_idle_unload_enabled`` to ``None`` when tearing down the + server or switching event loops (e.g. in test fixtures). 
+ """ + global _idle_unload_enabled + if _idle_unload_enabled is None: + _idle_unload_enabled = asyncio.Event() + _idle_unload_enabled.set() + return _idle_unload_enabled + + +def _invalidate_tool_parser_cache(reason: str | None = None) -> None: + """Drop cached parser state when the serving tokenizer changes.""" + global _tool_parser_instance + if _tool_parser_instance is None: + return + + if reason: + logger.debug(f"Invalidating tool parser cache: {reason}") + _tool_parser_instance = None -def _load_prefix_cache_from_disk() -> None: + +def _load_prefix_cache_from_disk(engine: BaseEngine | None = None) -> None: """Load prefix cache from disk during startup.""" + target_engine = engine or _engine + if target_engine is None: + return + try: d = _get_cache_dir() logger.info(f"[lifespan] Loading prefix cache from {d}") - loaded = _engine.load_cache_from_disk(d) + loaded = target_engine.load_cache_from_disk(d) if loaded > 0: logger.info(f"[lifespan] Loaded {loaded} prefix cache entries") else: logger.info("[lifespan] No prefix cache entries found on disk") except Exception as e: - logger.warning(f"[lifespan] Failed to load cache from disk: {e}", exc_info=True) + logger.warning( + "[lifespan] Failed to load cache from disk: %s", + _sanitize_log_text(e, limit=500), + ) -def _save_prefix_cache_to_disk() -> None: +def _save_prefix_cache_to_disk(engine: BaseEngine | None = None) -> None: """Save prefix cache to disk during shutdown.""" + target_engine = engine or _engine + if target_engine is None: + return + try: d = _get_cache_dir() logger.info(f"[lifespan] Saving prefix cache to {d}") - saved = _engine.save_cache_to_disk(d) + saved = target_engine.save_cache_to_disk(d) if saved: logger.info(f"[lifespan] Saved prefix cache to {d}") else: logger.info("[lifespan] No cache to save") except Exception as e: - logger.warning(f"[lifespan] Failed to save cache to disk: {e}", exc_info=True) + logger.warning( + "[lifespan] Failed to save cache to disk: %s", + _sanitize_log_text(e, limit=500), + ) def _get_cache_dir() -> str: @@ -223,42 +673,298 @@ def _get_cache_dir() -> str: return cache_dir +def _build_engine(spec: ModelSpec) -> BaseEngine: + """Construct an engine instance from a model spec without starting it.""" + if spec.use_batching: + from .engine.batched import BatchedEngine + + logger.info(f"Preparing BatchedEngine for residency: {spec.model_name}") + return BatchedEngine( + model_name=spec.model_name, + scheduler_config=spec.scheduler_config, + stream_interval=spec.stream_interval, + force_mllm=spec.force_mllm, + ) + + from .engine.simple import SimpleEngine + + logger.info(f"Preparing SimpleEngine for residency: {spec.model_name}") + return SimpleEngine( + model_name=spec.model_name, + force_mllm=spec.force_mllm, + mtp=spec.mtp, + prefill_step_size=spec.prefill_step_size, + specprefill_enabled=spec.specprefill_enabled, + specprefill_threshold=spec.specprefill_threshold, + specprefill_keep_pct=spec.specprefill_keep_pct, + specprefill_draft_model=spec.specprefill_draft_model, + ) + + +async def _engine_factory(spec: ModelSpec) -> BaseEngine: + """Async engine factory used by the residency manager.""" + return _build_engine(spec) + + +async def _run_blocking_engine_cache_io(io_fn, engine: BaseEngine) -> None: + """Run blocking cache persistence off the event loop. + + If the caller is canceled while waiting, finish the in-flight thread before + propagating cancellation so engine state cannot keep mutating in the + background after lifecycle cleanup has started. 
+ """ + task = asyncio.create_task(asyncio.to_thread(io_fn, engine)) + try: + await asyncio.shield(task) + except asyncio.CancelledError: + with suspend_cancellation(): + while not task.done(): + try: + await asyncio.shield(task) + except asyncio.CancelledError: + continue + except Exception: + break + raise + + +async def _restore_engine_state(spec: ModelSpec, engine: BaseEngine) -> None: + """Restore engine-local state, such as prefix cache, after a cold load.""" + if hasattr(engine, "load_cache_from_disk"): + await _run_blocking_engine_cache_io(_load_prefix_cache_from_disk, engine) + + +async def _persist_engine_state(spec: ModelSpec, engine: BaseEngine) -> None: + """Persist engine-local state before an idle unload or shutdown unload.""" + if hasattr(engine, "save_cache_to_disk"): + await _run_blocking_engine_cache_io(_save_prefix_cache_to_disk, engine) + + +def _activate_engine(engine: BaseEngine | None) -> BaseEngine | None: + """Set the global engine pointer and refresh parser-sensitive state.""" + global _engine + + if engine is not _engine: + _invalidate_tool_parser_cache("resident engine changed") + _engine = engine + if _engine is not None: + _engine.preserve_native_tool_format = _detect_native_tool_support() + return _engine + + +def _sync_engine_from_residency() -> BaseEngine | None: + """Sync the global engine pointer from the residency manager state. + + Safety: all callers run on the single-threaded asyncio event loop and do not + yield between reading the residency state and writing ``_engine``, so no + additional locking is required. + """ + if _residency_manager is None or _default_model_key is None: + return _engine + + return _activate_engine(_residency_manager.get_engine(_default_model_key)) + + +def _get_lifecycle_status() -> dict | None: + """Get lifecycle status for the default resident if lifecycle is enabled.""" + if _residency_manager is None or _default_model_key is None: + return None + return _residency_manager.get_status(_default_model_key) + + +def _public_lifecycle_status(lifecycle: dict | None) -> dict | None: + """Return residency status safe for unauthenticated public endpoints.""" + if lifecycle is None: + return None + public = dict(lifecycle) + if _model_name: + public["model_name"] = _model_name + # Surface a generic error indicator without exposing raw exception text. + if "last_error" in public: + public["last_error"] = ( + "model_load_failed" if public["last_error"] is not None else None + ) + return public + + +async def _lifecycle_loop() -> None: + """Background idle-unload loop for the default resident.""" + while True: + if _residency_manager is None or _default_model_key is None: + await asyncio.sleep(1.0) + continue + + # Block until idle-unload is enabled instead of polling. 
+ await _get_idle_unload_event().wait() + + try: + await _residency_manager.unload_if_idle(_default_model_key) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Idle unload iteration failed") + finally: + _sync_engine_from_residency() + + sleep_for = min(_auto_unload_idle_seconds / 2, 5.0) + await asyncio.sleep(sleep_for) + + +async def _acquire_default_engine(*, count_activity: bool = True) -> BaseEngine: + """Acquire the default engine, auto-loading via the residency manager if needed.""" + if _residency_manager is None or _default_model_key is None: + return get_engine() + + if count_activity: + engine = await _residency_manager.acquire(_default_model_key) + else: + engine = await _residency_manager.acquire( + _default_model_key, + count_activity=False, + ) + activated_engine = _activate_engine(engine) + if activated_engine is None: + raise HTTPException(status_code=503, detail="Model not loaded") + return activated_engine + + +async def _release_default_engine(*, count_activity: bool = True) -> None: + """Release the default engine after request processing.""" + if _residency_manager is None or _default_model_key is None: + return + + if count_activity: + await _residency_manager.release(_default_model_key) + else: + await _residency_manager.release(_default_model_key, count_activity=False) + _sync_engine_from_residency() + + async def lifespan(app: FastAPI): """FastAPI lifespan for startup/shutdown events.""" - global _engine, _mcp_manager + global _engine, _mcp_manager, _lifecycle_task, _lifespan_active + primary_exc: BaseException | None = None + try: + _get_idle_unload_event().clear() + + # Startup: ensure resident is loaded on the serving event loop when lifecycle + # management is enabled, unless lazy startup is requested. + if _residency_manager is not None and _default_model_key is not None: + if not _lazy_load_model: + await _residency_manager.ensure_loaded(_default_model_key) + _sync_engine_from_residency() + elif ( + _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded + ): + await _engine.start() + + # Load persisted cache from disk (AFTER engine start — AsyncEngineCore must exist) + if ( + _residency_manager is None + and _engine is not None + and hasattr(_engine, "load_cache_from_disk") + ): + _load_prefix_cache_from_disk() - # Startup: Start engine if loaded (needed for BatchedEngine in uvicorn's event loop) - if _engine is not None and hasattr(_engine, "_loaded") and not _engine._loaded: - await _engine.start() + # Warm up prefix cache with user-provided prompts (AFTER disk cache load, + # so any already-persisted entries are preserved and warm-up only fills + # gaps). 
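+        # The file referenced by --warm-prompts is parsed by
+        # prompt_warmup.load_warmup_file; each entry is replayed through the
+        # engine so its prefix lands in the cache before traffic arrives.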
+ if ( + _warm_prompts_path + and _engine is not None + and hasattr(_engine, "stream_chat") + ): + try: + from vllm_mlx.prompt_warmup import load_warmup_file, warm_prefix_cache + + prompts = load_warmup_file(_warm_prompts_path) + logger.info( + "[lifespan] Warming prefix cache with %d prompts from %s", + len(prompts), + _warm_prompts_path, + ) + result = await warm_prefix_cache(_engine, prompts) + logger.info( + "[lifespan] Warm-up done (%s): %d completed, %d skipped, %d prompt tokens in %.1fs", + result.get("mode", "?"), + result["count"], + result["skipped"], + result["total_prompt_tokens"], + result["elapsed_ms"] / 1000, + ) + except Exception as e: + logger.warning( + "[lifespan] Warm-up failed: %s", + _sanitize_log_text(e, limit=500), + ) - # Load persisted cache from disk (AFTER engine start — AsyncEngineCore must exist) - if _engine is not None and hasattr(_engine, "load_cache_from_disk"): - _load_prefix_cache_from_disk() + if _residency_manager is not None and _auto_unload_idle_seconds > 0: + _lifecycle_task = asyncio.create_task(_lifecycle_loop()) - # Initialize MCP if config provided - mcp_config = os.environ.get("VLLM_MLX_MCP_CONFIG") - if mcp_config: - await init_mcp(mcp_config) + # Initialize MCP if config provided + mcp_config = os.environ.get("VLLM_MLX_MCP_CONFIG") + if mcp_config: + await init_mcp(mcp_config) - yield + _get_idle_unload_event().set() + _lifespan_active = True + yield + except BaseException as exc: + primary_exc = exc - # Shutdown: Save cache to disk BEFORE stopping engine - if _engine is not None and hasattr(_engine, "save_cache_to_disk"): - _save_prefix_cache_to_disk() + cleanup_exc: BaseException | None = None + try: + # Shutdown: Save cache to disk BEFORE stopping engine + if ( + _residency_manager is None + and _engine is not None + and hasattr(_engine, "save_cache_to_disk") + ): + _save_prefix_cache_to_disk() + + # Shutdown: Close MCP connections and stop engine + if _lifecycle_task is not None: + _lifecycle_task.cancel() + with suppress(asyncio.CancelledError): + await _lifecycle_task + _lifecycle_task = None + if _mcp_manager is not None: + await _mcp_manager.stop() + logger.info("MCP manager stopped") + if _residency_manager is not None: + await _residency_manager.shutdown() + _sync_engine_from_residency() + logger.info("Lifecycle manager shut down") + elif _engine is not None: + await _engine.stop() + _engine = None + logger.info("Engine stopped") + except BaseException as exc: + cleanup_exc = exc + finally: + _get_idle_unload_event().set() + _lifespan_active = False + + if primary_exc is not None: + if cleanup_exc is not None: + logger.error( + "Lifecycle cleanup failed while preserving the original exception", + exc_info=( + type(cleanup_exc), + cleanup_exc, + cleanup_exc.__traceback__, + ), + ) + raise primary_exc - # Shutdown: Close MCP connections and stop engine - if _mcp_manager is not None: - await _mcp_manager.stop() - logger.info("MCP manager stopped") - if _engine is not None: - await _engine.stop() - logger.info("Engine stopped") + if cleanup_exc is not None: + raise cleanup_exc app = FastAPI( title="vllm-mlx API", description="OpenAI-compatible API for MLX LLM/MLLM inference on Apple Silicon", - version="0.2.1", + version="0.2.9", lifespan=lifespan, ) @@ -472,8 +1178,33 @@ def _validate_model_name(request_model: str) -> None: ) +def _get_engine_tokenizer(engine: BaseEngine | None) -> object | None: + """Return tokenizer-like parser state from the active engine.""" + if engine is None: + return None + tokenizer = getattr(engine, 
"tokenizer", None) + if tokenizer is not None: + return tokenizer + return getattr(engine, "_tokenizer", None) + + +def _get_or_init_tool_parser(engine: BaseEngine | None = None): + """Return the cached tool parser, initializing it from the given engine.""" + global _tool_parser_instance + + if _tool_parser_instance is None: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = _get_engine_tokenizer(engine if engine is not None else _engine) + _tool_parser_instance = parser_cls(tokenizer) + logger.info(f"Initialized tool call parser: {_tool_call_parser}") + + return _tool_parser_instance + + def _parse_tool_calls_with_parser( - output_text: str, request: ChatCompletionRequest | None = None + output_text: str, + request: ChatCompletionRequest | None = None, + engine: BaseEngine | None = None, ) -> tuple[str, list | None]: """ Parse tool calls from model output using the configured parser. @@ -484,6 +1215,7 @@ def _parse_tool_calls_with_parser( Args: output_text: The model output text request: The original request (for context) + engine: The request-local engine to use for parser initialization Returns: Tuple of (cleaned_text, tool_calls) @@ -507,16 +1239,12 @@ def _parse_tool_calls_with_parser( # Initialize parser if needed if _tool_parser_instance is None: try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - # Get tokenizer from engine if available - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - logger.info(f"Initialized tool call parser: {_tool_call_parser}") + _get_or_init_tool_parser(engine) except Exception as e: logger.warning( - f"Failed to initialize tool parser '{_tool_call_parser}': {e}" + "Failed to initialize tool parser '%s': %s", + _tool_call_parser, + _sanitize_log_text(e, limit=500), ) logger.warning("Falling back to generic parser") return parse_tool_calls(output_text, request_dict) @@ -547,26 +1275,975 @@ def _parse_tool_calls_with_parser( # try generic parser which handles more formats (e.g. Nemotron XML) return parse_tool_calls(output_text, request_dict) except Exception as e: - logger.warning(f"Tool parser error: {e}") + logger.warning("Tool parser error: %s", _sanitize_log_text(e, limit=500)) return parse_tool_calls(output_text, request_dict) -def _detect_native_tool_support() -> bool: - """ - Detect if the active tool parser supports native tool format. +def _new_response_item_id(prefix: str) -> str: + """Generate stable OpenAI-style item ids.""" + return f"{prefix}_{uuid.uuid4().hex}" - Native format means role="tool" messages and tool_calls fields - are preserved instead of being converted to text. 
- Returns: - True if native format should be preserved - """ - if not _enable_auto_tool_choice or not _tool_call_parser: - return False +def _response_content_to_text(content) -> str: + """Normalize Responses API content items into plain text.""" + if content is None: + return "" + if isinstance(content, str): + return content - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - return parser_cls.supports_native_format() + text_parts = [] + for part in content: + if isinstance(part, dict): + part_type = part.get("type") + text = part.get("text", "") + else: + part_type = getattr(part, "type", None) + text = getattr(part, "text", "") + if part_type in {"text", "input_text", "output_text"}: + text_parts.append(text) + return "\n".join(part for part in text_parts if part) + + +def _responses_tools_to_chat_tools( + tools: list[ResponseFunctionTool | dict], +) -> tuple[list[dict] | None, list[str]]: + """Convert supported Responses tools and report unsupported tool types.""" + if not tools: + return None, [] + + supported: list[dict] = [] + unsupported: list[str] = [] + + for tool in tools: + if isinstance(tool, ResponseFunctionTool): + tool_type = tool.type + tool_name = tool.name + tool_description = tool.description or "" + tool_parameters = tool.parameters + elif isinstance(tool, dict): + tool_type = tool.get("type", "unknown") + tool_name = tool.get("name", "") + tool_description = tool.get("description", "") + tool_parameters = tool.get("parameters", {}) + else: + unsupported.append(type(tool).__name__) + continue + + if tool_type == "function": + supported.append( + { + "type": "function", + "function": { + "name": tool_name, + "description": tool_description, + "parameters": tool_parameters + or {"type": "object", "properties": {}}, + }, + } + ) + else: + unsupported.append(tool_type) + + return supported or None, unsupported + + +def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]: + """Convert Responses API input items into chat-completions-style messages.""" + messages: list[dict] = [] + + if request.previous_response_id: + previous = _responses_store.get(request.previous_response_id) + if previous is None: + raise HTTPException( + status_code=404, + detail=f"Previous response `{request.previous_response_id}` not found", + ) + messages.extend(copy.deepcopy(previous["messages"])) + + if request.instructions: + messages.append({"role": "system", "content": request.instructions}) + + if isinstance(request.input, str): + messages.append({"role": "user", "content": request.input}) + return messages + + for item in request.input: + if isinstance(item, dict): + item_type = item.get("type", "") + if item_type == "message": + role = item.get("role", "user") + if role == "developer": + role = "system" + messages.append( + { + "role": role, + "content": _response_content_to_text(item.get("content")), + } + ) + elif item_type == "function_call": + messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": item.get( + "call_id", _new_response_item_id("call") + ), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", ""), + }, + } + ], + } + ) + elif item_type == "function_call_output": + messages.append( + { + "role": "tool", + "tool_call_id": item.get("call_id", ""), + "content": item.get("output", ""), + } + ) + elif item_type == "reasoning": + parts = item.get("content", []) + reasoning_text = "\n".join( + p.get("text", "") for p in parts if isinstance(p, 
dict) + ) + if reasoning_text: + messages.append({"role": "assistant", "content": reasoning_text}) + else: + logger.info( + "Skipping unsupported Responses input item type %r", item_type + ) + continue + + if isinstance(item, ResponseMessageItem): + role = item.role + if role == "developer": + role = "system" + messages.append( + { + "role": role, + "content": _response_content_to_text(item.content), + } + ) + elif isinstance(item, ResponseFunctionCallItem): + messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": item.call_id, + "type": "function", + "function": { + "name": item.name, + "arguments": item.arguments, + }, + } + ], + } + ) + elif isinstance(item, ResponseFunctionCallOutputItem): + messages.append( + { + "role": "tool", + "tool_call_id": item.call_id, + "content": item.output, + } + ) + elif isinstance(item, ResponseReasoningItem): + reasoning_text = "\n".join(part.text for part in (item.content or [])) + if reasoning_text: + messages.append({"role": "assistant", "content": reasoning_text}) + else: + logger.info( + "Skipping unsupported Responses input item type %r", + getattr(item, "type", type(item).__name__), + ) + + return messages + + +def _responses_request_to_new_persisted_messages( + request: ResponsesRequest, +) -> list[dict]: + """Persist only the current request's replayable input items.""" + request_without_history = request.model_copy( + update={"previous_response_id": None, "instructions": None}, + deep=True, + ) + return _responses_input_to_chat_messages(request_without_history) + + +def _responses_request_to_persisted_messages(request: ResponsesRequest) -> list[dict]: + """Persist replayable history for chained previous_response_id requests. + + Responses `instructions` are intentionally not replayed across + `previous_response_id`, but replayable message items are. + """ + messages: list[dict] = [] + if request.previous_response_id: + previous = _responses_store.get(request.previous_response_id) + if previous is None: + raise HTTPException( + status_code=404, + detail=f"Previous response `{request.previous_response_id}` not found", + ) + messages.extend(copy.deepcopy(previous["messages"])) + messages.extend(_responses_request_to_new_persisted_messages(request)) + return messages + + +def _responses_request_to_chat_request( + request: ResponsesRequest, +) -> ChatCompletionRequest: + """Build a ChatCompletionRequest from a ResponsesRequest.""" + if request.text.format.type == "json_object": + raise HTTPException( + status_code=400, + detail="Responses text.format.type='json_object' is not supported on this backend", + ) + if request.reasoning is not None: + logger.debug("Ignoring reasoning configuration (not supported on this backend)") + + tools, unsupported_tools = _responses_tools_to_chat_tools(request.tools) + messages = _responses_input_to_chat_messages(request) + if unsupported_tools: + tool_list = ", ".join(sorted(set(unsupported_tools))) + messages.insert( + 0, + { + "role": "system", + "content": ( + "The following requested tool types are not available on this " + f"backend: {tool_list}. Do not call them." 
+ ), + }, + ) + + system_messages = [msg for msg in messages if msg.get("role") == "system"] + non_system_messages = [msg for msg in messages if msg.get("role") != "system"] + merged_system_content = "\n\n".join( + str(msg.get("content", "")).strip() + for msg in system_messages + if str(msg.get("content", "")).strip() + ) + messages = ( + [{"role": "system", "content": merged_system_content}] + if merged_system_content + else [] + ) + non_system_messages + + return ChatCompletionRequest( + model=request.model, + messages=[Message(**msg) for msg in messages], + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_output_tokens, + stream=False, + tools=tools, + tool_choice=request.tool_choice, + chat_template_kwargs=request.chat_template_kwargs, + ) + + +def _build_responses_output_items( + text: str | None, + reasoning: str | None, + tool_calls: list[ToolCall] | None, +) -> list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem]: + """Convert parsed assistant output into Responses API output items.""" + output_items: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ] = [] + + if reasoning: + output_items.append( + ResponseReasoningItem( + id=_new_response_item_id("rs"), + content=[ResponseReasoningTextPart(text=reasoning)], + ) + ) + + if text: + output_items.append( + ResponseMessageItem( + id=_new_response_item_id("msg"), + role="assistant", + content=[ResponseTextContentPart(type="output_text", text=text)], + ) + ) + + for tool_call in tool_calls or []: + output_items.append( + ResponseFunctionCallItem( + id=_new_response_item_id("fc"), + call_id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + ) + + return output_items + + +def _response_output_items_to_chat_messages(output_items: list) -> list[dict]: + """Persist assistant output in chat-completions form for previous_response_id.""" + assistant_text_parts: list[str] = [] + assistant_tool_calls: list[dict] = [] + + for item in output_items: + if isinstance(item, ResponseMessageItem): + assistant_text_parts.append(_response_content_to_text(item.content)) + elif isinstance(item, ResponseFunctionCallItem): + assistant_tool_calls.append( + { + "id": item.call_id, + "type": "function", + "function": { + "name": item.name, + "arguments": item.arguments, + }, + } + ) + + if not assistant_text_parts and not assistant_tool_calls: + return [] + + return [ + { + "role": "assistant", + "content": "".join(assistant_text_parts), + "tool_calls": assistant_tool_calls or None, + } + ] + + +def _build_response_object( + request: ResponsesRequest, + output_items: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ], + prompt_tokens: int, + completion_tokens: int, + finish_reason: str | None, + response_id: str | None = None, +) -> ResponseObject: + """Build a full Responses API object.""" + response = ResponseObject( + id=response_id or _new_response_item_id("resp"), + model=_model_name or request.model, + instructions=request.instructions, + max_output_tokens=request.max_output_tokens, + metadata=request.metadata, + output=output_items, + parallel_tool_calls=request.parallel_tool_calls, + previous_response_id=request.previous_response_id, + text=request.text, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=_resolve_top_p(request.top_p), + temperature=_resolve_temperature(request.temperature), + truncation=request.truncation, + user=request.user, + store=request.store, + 
usage=ResponsesUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + if finish_reason == "length": + response.status = "incomplete" + response.incomplete_details = ResponseIncompleteDetails( + reason="max_output_tokens" + ) + return response + + +def _prepare_responses_request( + request: ResponsesRequest, +) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict]: + """Prepare a Responses request for execution on the chat engine.""" + _validate_model_name(request.model) + engine = get_engine() + chat_request = _responses_request_to_chat_request(request) + + if chat_request.messages: + logger.info( + f"[REQUEST] POST /v1/responses stream={request.stream} " + f"model={request.model!r} items=" + f"{len(request.input) if isinstance(request.input, list) else 1} " + f"tools={len(request.tools)}" + ) + + messages, images, videos = extract_multimodal_content( + chat_request.messages, + preserve_native_format=engine.preserve_native_tool_format, + ) + + chat_kwargs = { + "max_tokens": chat_request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(chat_request.temperature), + "top_p": _resolve_top_p(chat_request.top_p), + } + resolved_chat_template_kwargs = _resolve_chat_template_kwargs( + chat_request.chat_template_kwargs + ) + if resolved_chat_template_kwargs: + chat_kwargs["chat_template_kwargs"] = resolved_chat_template_kwargs + if request.tools: + chat_kwargs["tools"] = convert_tools_for_template(chat_request.tools) + if images: + chat_kwargs["images"] = images + if videos: + chat_kwargs["videos"] = videos + + return engine, chat_request, messages, chat_kwargs + + +async def _run_responses_request( + request: ResponsesRequest, + raw_request: Request, +) -> tuple[ResponseObject | None, list[dict]]: + """Execute a Responses API request against the backend chat engine.""" + engine, chat_request, messages, chat_kwargs = _prepare_responses_request(request) + + timeout = _default_timeout + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) + if output is None: + return None, [] + + cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, chat_request) + reasoning_text = None + if _reasoning_parser: + reasoning_text, remaining_text = _reasoning_parser.extract_reasoning( + output.text + ) + if not tool_calls: + cleaned_text = remaining_text + else: + # Tool parser already stripped tool markup from cleaned_text, + # but reasoning markers (e.g. <|channel>thought...) + # remain. Run reasoning parser on cleaned_text to strip them. 
+ _, cleaned_text = _reasoning_parser.extract_reasoning(cleaned_text or "") + + output_items = _build_responses_output_items( + clean_output_text(cleaned_text) if cleaned_text else None, + reasoning_text, + tool_calls, + ) + response_object = _build_response_object( + request=request, + output_items=output_items, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + + persisted_messages = _responses_request_to_persisted_messages(request) + persisted_messages.extend(_response_output_items_to_chat_messages(output_items)) + if request.store: + _responses_store[response_object.id] = { + "messages": copy.deepcopy(persisted_messages), + "response": response_object.model_copy(deep=True), + } + while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE: + _responses_store.popitem(last=False) + + return response_object, persisted_messages + + +async def _stream_responses_request(request: ResponsesRequest) -> AsyncIterator[str]: + """Execute a Responses API request and stream SSE events incrementally.""" + engine, chat_request, messages, chat_kwargs = _prepare_responses_request(request) + + response_id = _new_response_item_id("resp") + sequence = 1 + base_response = _build_response_object( + request=request, + output_items=[], + prompt_tokens=0, + completion_tokens=0, + finish_reason=None, + response_id=response_id, + ) + base_response.status = "in_progress" + base_response.usage = None + + yield _responses_sse_event( + "response.created", + ResponseCreatedEvent(sequence_number=sequence, response=base_response), + ) + sequence += 1 + yield _responses_sse_event( + "response.in_progress", + ResponseInProgressEvent(sequence_number=sequence, response=base_response), + ) + sequence += 1 + + prompt_tokens = 0 + completion_tokens = 0 + finish_reason = None + last_output = None + raw_accumulated_text = "" + accumulated_text = "" + accumulated_reasoning = "" + + text_item_id: str | None = None + text_output_index: int | None = None + reasoning_item_id: str | None = None + reasoning_output_index: int | None = None + next_output_index = 0 + + def _start_text_item() -> list[str]: + nonlocal text_item_id, text_output_index, next_output_index, sequence + events: list[str] = [] + if text_item_id is None: + text_item_id = _new_response_item_id("msg") + text_output_index = next_output_index + next_output_index += 1 + events.append( + _responses_sse_event( + "response.output_item.added", + ResponseOutputItemAddedEvent( + sequence_number=sequence, + output_index=text_output_index, + item=ResponseMessageItem( + id=text_item_id, + role="assistant", + status="in_progress", + content=[], + ), + ), + ) + ) + sequence += 1 + events.append( + _responses_sse_event( + "response.content_part.added", + ResponseContentPartAddedEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + part=ResponseTextContentPart(type="output_text", text=""), + ), + ) + ) + sequence += 1 + return events + + def _start_reasoning_item() -> list[str]: + nonlocal reasoning_item_id, reasoning_output_index, next_output_index, sequence + events: list[str] = [] + if reasoning_item_id is None: + reasoning_item_id = _new_response_item_id("rs") + reasoning_output_index = next_output_index + next_output_index += 1 + events.append( + _responses_sse_event( + "response.output_item.added", + ResponseOutputItemAddedEvent( + sequence_number=sequence, + output_index=reasoning_output_index, + item=ResponseReasoningItem( + 
id=reasoning_item_id, + status="in_progress", + content=[], + ), + ), + ) + ) + sequence += 1 + events.append( + _responses_sse_event( + "response.content_part.added", + ResponseContentPartAddedEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + part=ResponseReasoningTextPart(text=""), + ), + ) + ) + sequence += 1 + return events + + if _reasoning_parser: + _reasoning_parser.reset_state() + + global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_markup_possible = False + if _enable_auto_tool_choice and _tool_call_parser: + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + logger.info( + "Initialized tool call parser for responses streaming: %s", + _tool_call_parser, + ) + except Exception as e: + logger.warning( + "Failed to init tool parser for responses streaming: %s", e + ) + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance + tool_parser.reset() + + async for output in engine.stream_chat(messages=messages, **chat_kwargs): + last_output = output + finish_reason = output.finish_reason + if hasattr(output, "prompt_tokens") and output.prompt_tokens: + prompt_tokens = output.prompt_tokens + if hasattr(output, "completion_tokens") and output.completion_tokens: + completion_tokens = output.completion_tokens + + delta_text = output.new_text or "" + if not delta_text: + continue + + previous_text = raw_accumulated_text + raw_accumulated_text += delta_text + + if _reasoning_parser: + delta_msg = _reasoning_parser.extract_reasoning_streaming( + previous_text, raw_accumulated_text, delta_text + ) + if delta_msg is None: + continue + + if delta_msg.reasoning: + for event in _start_reasoning_item(): + yield event + accumulated_reasoning += delta_msg.reasoning + yield _responses_sse_event( + "response.reasoning_text.delta", + ResponseReasoningTextDeltaEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + delta=delta_msg.reasoning, + ), + ) + sequence += 1 + + if delta_msg.content: + for event in _start_text_item(): + yield event + accumulated_text += delta_msg.content + yield _responses_sse_event( + "response.output_text.delta", + ResponseOutputTextDeltaEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + delta=delta_msg.content, + ), + ) + sequence += 1 + continue + + content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + if tool_parser and delta_text: + # Fast path: skip parsing until a tool-markup marker appears. 
+ # Use _streaming_tool_markup_possible to catch all supported + # shapes (, _RESPONSES_STORE_MAX_SIZE: + _responses_store.popitem(last=False) + + yield _responses_sse_event( + "response.completed", + ResponseCompletedEvent(sequence_number=sequence, response=response_object), + ) + + +def _responses_sse_event(event_type: str, payload: BaseModel | dict) -> str: + """Encode a Responses API SSE event.""" + data = ( + payload.model_dump_json() + if isinstance(payload, BaseModel) + else json.dumps(payload) + ) + return f"event: {event_type}\ndata: {data}\n\n" + + +def _extract_reasoning_and_tool_calls( + output_text: str, + request: ChatCompletionRequest | None = None, + *, + allow_reasoning: bool = True, + engine: BaseEngine | None = None, +) -> tuple[str | None, str | None, list[ToolCall] | None]: + """ + Extract reasoning first, then parse tool calls from the cleaned content. + + Non-streaming responses can contain both a reasoning block and structured + tool calls in the same final output. If tool parsing runs first and the + response contains tools, the caller can no longer reliably recover the + reasoning segment because the usual response path skips reasoning parsing + once tool_calls is truthy. + """ + reasoning_text = None + text_for_tool_parse = output_text + + if _reasoning_parser and allow_reasoning: + reasoning_text, cleaned_reasoning_text = _reasoning_parser.extract_reasoning( + output_text + ) + if cleaned_reasoning_text is not None: + text_for_tool_parse = cleaned_reasoning_text + elif reasoning_text is not None: + text_for_tool_parse = "" + + # Skip tool parsing when the request defines no tools — otherwise the + # parser can misinterpret JSON output (e.g. response_format) as tool calls. + if request is not None and getattr(request, "tools", None): + try: + cleaned_text, tool_calls = _parse_tool_calls_with_parser( + text_for_tool_parse or "", + request, + engine=engine, + ) + except TypeError as exc: + if "unexpected keyword argument 'engine'" not in str(exc): + raise + cleaned_text, tool_calls = _parse_tool_calls_with_parser( + text_for_tool_parse or "", + request, + ) + else: + cleaned_text, tool_calls = text_for_tool_parse, None + + return reasoning_text, cleaned_text, tool_calls + + +def _detect_native_tool_support() -> bool: + """ + Detect if the active tool parser supports native tool format. + + Native format means role="tool" messages and tool_calls fields + are preserved instead of being converted to text. 
+ + Returns: + True if native format should be preserved + """ + if not _enable_auto_tool_choice or not _tool_call_parser: + return False + + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + return parser_cls.supports_native_format() except KeyError: # Parser not found - this is a configuration error, log as error logger.error( @@ -576,9 +2253,78 @@ def _detect_native_tool_support() -> bool: return False except Exception as e: # Unexpected error during detection - logger.warning(f"Failed to detect native tool support: {e}") + logger.warning( + "Failed to detect native tool support: %s", + _sanitize_log_text(e, limit=500), + ) + return False + + +def _tool_choice_disabled(request: ChatCompletionRequest | None) -> bool: + """Return True when tool_choice explicitly disables tool calling.""" + if request is None: return False + tool_choice = getattr(request, "tool_choice", None) + if tool_choice is None: + request_dict = request.model_dump() + tool_choice = request_dict.get("tool_choice") + return tool_choice == "none" + + +def _get_streaming_tool_parser( + request: ChatCompletionRequest | None, + engine: BaseEngine | None = None, +): + """Get a streaming-capable tool parser for this request. + + Uses the configured parser when auto tool choice is enabled, otherwise falls + back to the generic auto parser so streaming still matches the generic + non-streaming tool parsing behavior. + """ + global _tool_parser_instance + + if request is None: + return None + if _tool_choice_disabled(request): + return None + + tokenizer = _get_engine_tokenizer(engine if engine is not None else _engine) + + if _enable_auto_tool_choice and _tool_call_parser: + if _tool_parser_instance is None: + try: + _get_or_init_tool_parser(engine) + except Exception as e: + logger.warning( + "Failed to init tool parser for streaming: %s", + _sanitize_log_text(e, limit=500), + ) + return None + _tool_parser_instance.reset() + return _tool_parser_instance + + if not getattr(request, "tools", None): + return None + + try: + parser_cls = ToolParserManager.get_tool_parser("auto") + parser = parser_cls(tokenizer) + parser.reset() + return parser + except Exception as e: + logger.warning(f"Failed to init generic streaming tool parser: {e}") + return None + + +def _streaming_tool_markup_possible(text: str) -> bool: + """Heuristic marker check to avoid parser work on ordinary text chunks.""" + return ( + any(marker in text for marker in _STREAMING_TOOL_MARKERS) + or _STREAMING_BARE_BRACKET_MARKER.search(text) is not None + or _STREAMING_BARE_BRACKET_PARTIAL.search(text) is not None + ) + def load_embedding_model( model_name: str | None, @@ -608,21 +2354,54 @@ def load_embedding_model( _embedding_engine.load() +def load_reranker_model( + model_name: str | None, + *, + lock: bool = False, + reuse_existing: bool = True, +) -> None: + """Load or reuse the reranker model engine when configured.""" + global _rerank_engine, _rerank_model_locked + + if not model_name: + return + + if lock: + _rerank_model_locked = model_name + + if ( + reuse_existing + and _rerank_engine is not None + and _rerank_engine.model_name == model_name + ): + return + + from .rerank import RerankEngine + + _rerank_engine = RerankEngine(model_name) + _rerank_engine.load() + + def load_model( model_name: str, use_batching: bool = False, scheduler_config=None, stream_interval: int = 1, max_tokens: int = 32768, + max_request_tokens: int = 32768, force_mllm: bool = False, gpu_memory_utilization: float = 0.90, served_model_name: str | None = None, 
+ trust_remote_code: bool = False, mtp: bool = False, prefill_step_size: int = 2048, specprefill_enabled: bool = False, specprefill_threshold: int = 8192, specprefill_keep_pct: float = 0.3, specprefill_draft_model: str = None, + warm_prompts_path: str | None = None, + auto_unload_idle_seconds: float = 0.0, + lazy_load_model: bool = False, ): """ Load a model (auto-detects MLLM vs LLM). @@ -633,29 +2412,121 @@ def load_model( scheduler_config: Scheduler config for batched mode stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation + max_request_tokens: Maximum max_tokens accepted from API clients force_mllm: Force loading as MLLM even if not auto-detected + trust_remote_code: Allow HuggingFace remote code execution during model/tokenizer loading mtp: Enable native MTP speculative decoding (SimpleEngine only) prefill_step_size: Chunk size for prompt prefill processing (default: 2048) specprefill_enabled: Enable SpecPrefill (SimpleEngine only) specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default: 8192) specprefill_keep_pct: Fraction of tokens to keep (default: 0.3) specprefill_draft_model: Path to small draft model for SpecPrefill scoring + auto_unload_idle_seconds: Idle time before auto-unloading the main model. + When non-zero, the main model is managed through lifecycle + residency instead of being loaded immediately in this function. + lazy_load_model: When lifecycle residency is enabled, defer the first + resident load until the first request instead of FastAPI lifespan + startup. """ - global _engine, _model_name, _model_path, _default_max_tokens, _tool_parser_instance + global _engine, _model_name, _model_path, _default_max_tokens + global _max_request_tokens, _tool_parser_instance, _warm_prompts_path + global _default_model_key, _auto_unload_idle_seconds, _residency_manager + global _force_mllm_model, _lazy_load_model, _lifespan_active + + _warm_prompts_path = warm_prompts_path + + if max_tokens < 1: + raise ValueError("Default max tokens must be at least 1") + if max_request_tokens < 1: + raise ValueError("Max request tokens must be at least 1") + if max_tokens > max_request_tokens: + raise ValueError("Default max tokens cannot exceed max request tokens") + + if _lifespan_active: + raise RuntimeError( + "Cannot call load_model() after FastAPI lifespan startup; " + "restart the server to reconfigure the main model" + ) + + if _residency_manager is None and _engine is not None: + existing_loaded_attr = getattr(_engine, "_loaded", False) + existing_stopped_attr = getattr(_engine, "stopped", False) + existing_loaded = ( + existing_loaded_attr if isinstance(existing_loaded_attr, bool) else False + ) + existing_stopped = ( + existing_stopped_attr if isinstance(existing_stopped_attr, bool) else None + ) + existing_live = existing_loaded or existing_stopped is False + if auto_unload_idle_seconds > 0 or lazy_load_model or existing_live: + raise RuntimeError("Cannot replace an existing engine while it is live") + + if _residency_manager is not None and _default_model_key is not None: + existing_engine = _residency_manager.get_engine(_default_model_key) + existing_status = _residency_manager.get_status(_default_model_key) + existing_state = existing_status.get("state") + if ( + existing_engine is not None + or existing_status.get("active_requests", 0) > 0 + or existing_state in {"loading", "loaded", "unloading"} + ): + raise RuntimeError( + "Cannot replace an existing residency manager while it is live" + ) 
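+    # Record serving defaults and model identity before any engine is built, so
+    # request validation and /health can report the configured model even while
+    # the resident is still loading or its first load is deferred (lazy_load_model).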
_default_max_tokens = max_tokens + _max_request_tokens = max_request_tokens _model_path = model_name _model_name = served_model_name or model_name + _default_model_key = "default" + _force_mllm_model = force_mllm + _auto_unload_idle_seconds = auto_unload_idle_seconds + _lazy_load_model = lazy_load_model # Reset tool parser instance when model is reloaded (tokenizer may change) - _tool_parser_instance = None + _invalidate_tool_parser_cache("model reloaded") if force_mllm: logger.info("Force MLLM mode enabled via --mllm flag") + if auto_unload_idle_seconds > 0 or lazy_load_model: + spec = ModelSpec( + model_key=_default_model_key, + model_name=model_name, + use_batching=use_batching, + scheduler_config=scheduler_config, + stream_interval=stream_interval if use_batching else 1, + max_tokens=max_tokens, + force_mllm=force_mllm, + mtp=mtp, + prefill_step_size=prefill_step_size, + specprefill_enabled=specprefill_enabled, + specprefill_threshold=specprefill_threshold, + specprefill_keep_pct=specprefill_keep_pct, + specprefill_draft_model=specprefill_draft_model, + ) + _residency_manager = ResidencyManager( + _engine_factory, + on_engine_loaded=_restore_engine_state, + on_engine_unloading=_persist_engine_state, + auto_unload_idle_seconds=auto_unload_idle_seconds, + ) + _residency_manager.register_model(spec) + _engine = None + logger.info( + "Lifecycle manager enabled: auto_unload_idle_seconds=%.1f", + auto_unload_idle_seconds, + ) + return + + _residency_manager = None + _auto_unload_idle_seconds = 0.0 + _lazy_load_model = False + if use_batching: logger.info(f"Loading model with BatchedEngine: {model_name}") _engine = BatchedEngine( model_name=model_name, + trust_remote_code=trust_remote_code, scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, @@ -665,9 +2536,16 @@ def load_model( # Just log for now logger.info(f"Model loaded (batched mode): {model_name}") else: + simple_engine_cls = SimpleEngine + if simple_engine_cls is _IMPORTED_SIMPLE_ENGINE: + from .engine import simple as simple_mod + + simple_engine_cls = simple_mod.SimpleEngine + logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine( + _engine = simple_engine_cls( model_name=model_name, + trust_remote_code=trust_remote_code, force_mllm=force_mllm, mtp=mtp, prefill_step_size=prefill_step_size, @@ -678,9 +2556,23 @@ def load_model( ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) + previous_loop = None + try: + previous_loop = asyncio.get_event_loop() + except RuntimeError: + previous_loop = None loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(_engine.start()) + try: + asyncio.set_event_loop(loop) + loop.run_until_complete(_engine.start()) + finally: + with suppress(Exception): + loop.run_until_complete(loop.shutdown_default_executor()) + loop.close() + if previous_loop is not None and not previous_loop.is_closed(): + asyncio.set_event_loop(previous_loop) + else: + asyncio.set_event_loop(None) model_type = "MLLM" if _engine.is_mllm else "LLM" logger.info(f"{model_type} model loaded (simple mode): {model_name}") @@ -690,6 +2582,7 @@ def load_model( logger.info(f"Native tool format enabled for parser: {_tool_call_parser}") logger.info(f"Default max tokens: {_default_max_tokens}") + logger.info(f"Max request tokens: {_max_request_tokens}") def get_usage(output: GenerationOutput) -> Usage: @@ -736,26 +2629,66 @@ async def health(): 
"tools_available": len(_mcp_manager.get_all_tools()), } - return { - "status": "healthy", + engine_stats = _engine.get_stats() if _engine else {} + lifecycle = _get_lifecycle_status() + health_status = ( + "unhealthy" + if lifecycle is not None and lifecycle.get("state") == "failed" + else "healthy" + ) + + payload = { + "status": health_status, "model_loaded": _engine is not None, "model_name": _model_name, - "model_type": "mllm" if (_engine and _engine.is_mllm) else "llm", + "model_type": ( + "mllm" + if (_engine and _engine.is_mllm) + or _force_mllm_model + or ( + _engine is None + and (_model_path or _model_name) + and is_mllm_model(_model_path or _model_name) + ) + else "llm" + ), + "engine_type": engine_stats.get("engine_type", "unknown"), "mcp": mcp_info, } + if lifecycle is not None: + lifecycle_fields = { + "residency_state": lifecycle["state"], + "active_requests": lifecycle["active_requests"], + "last_used_at": lifecycle["last_used_at"], + "loaded_at": lifecycle["loaded_at"], + "auto_unload_idle_seconds": lifecycle["auto_unload_idle_seconds"], + } + if lifecycle.get("state") == "failed": + lifecycle_fields["last_error"] = ( + "model_load_failed" if lifecycle.get("last_error") is not None else None + ) + payload.update(lifecycle_fields) + return payload -@app.get("/v1/status") +@app.get("/v1/status", dependencies=[Depends(verify_api_key)]) async def status(): """Real-time status with per-request details for debugging and monitoring.""" + lifecycle = _public_lifecycle_status(_get_lifecycle_status()) if _engine is None: - return {"status": "not_loaded", "model": None, "requests": []} + return { + "status": "not_loaded", + "model": _model_name, + "residency": lifecycle, + "requests": [], + } stats = _engine.get_stats() return { "status": "running" if stats.get("running") else "stopped", "model": _model_name, + "residency": lifecycle, "uptime_s": round(stats.get("uptime_seconds", 0), 1), "steps_executed": stats.get("steps_executed", 0), "num_running": stats.get("num_running", 0), @@ -775,9 +2708,16 @@ async def status(): } -@app.get("/v1/cache/stats") +@app.get("/v1/cache/stats", dependencies=[Depends(verify_api_key)]) async def cache_stats(): """Get cache statistics for debugging and monitoring.""" + engine_cache = None + if _engine is not None and hasattr(_engine, "get_cache_stats"): + try: + engine_cache = _engine.get_cache_stats() + except Exception as exc: + engine_cache = {"error": f"engine cache stats failed: {exc}"} + try: from mlx_vlm.utils import ( get_multimodal_kv_cache_stats, @@ -786,17 +2726,29 @@ async def cache_stats(): ) return { + "engine_cache": engine_cache, "multimodal_kv_cache": get_multimodal_kv_cache_stats(), "pixel_values_cache": get_pixel_values_cache_stats(), "pil_image_cache": get_pil_cache_stats(), } except ImportError: - return {"error": "Cache stats not available (mlx_vlm not loaded)"} + return { + "engine_cache": engine_cache, + "error": "Cache stats not available (mlx_vlm not loaded)", + } -@app.delete("/v1/cache") +@app.delete("/v1/cache", dependencies=[Depends(verify_api_key)]) async def clear_cache(): """Clear all caches.""" + cleared_engine = None + if _engine is not None and hasattr(_engine, "clear_runtime_caches"): + try: + cleared_engine = _engine.clear_runtime_caches() + except Exception as exc: + logger.warning("Failed to clear engine caches: %s", exc, exc_info=True) + cleared_engine = {"error": str(exc)} + try: from mlx_vlm.utils import ( clear_multimodal_kv_cache, @@ -807,10 +2759,69 @@ async def clear_cache(): clear_pixel_values_cache() 
return { "status": "cleared", + "engine_cache": cleared_engine, "caches": ["multimodal_kv", "pixel_values", "pil_image"], } except ImportError: - return {"error": "Cache clear not available (mlx_vlm not loaded)"} + return { + "status": "cleared", + "engine_cache": cleared_engine, + "error": "Cache clear not available (mlx_vlm not loaded)", + } + + +@app.delete("/v1/cache/prefix", dependencies=[Depends(verify_api_key)]) +async def clear_prefix_cache(): + """Clear the text prefix cache used for KV reuse in continuous batching. + + If the server was started with ``--warm-prompts``, the warm-up is + re-run in the background after clear so the next real request still + hits the cache. Response returns immediately without waiting for + the re-warm to finish. + """ + if _engine is None: + return {"status": "no_engine"} + cleared = False + if hasattr(_engine, "clear_prefix_cache"): + try: + _engine.clear_prefix_cache() + cleared = True + except Exception as e: + logger.warning( + "[clear_prefix_cache] engine.clear_prefix_cache failed: %s", + _sanitize_log_text(e, limit=500), + ) + + # Auto re-warm in background if warm-prompts was configured. + rewarm_scheduled = False + if cleared and _warm_prompts_path and hasattr(_engine, "stream_chat"): + + async def _rewarm(): + try: + from vllm_mlx.prompt_warmup import ( + load_warmup_file, + warm_prefix_cache, + ) + + prompts = load_warmup_file(_warm_prompts_path) + result = await warm_prefix_cache(_engine, prompts) + logger.info( + "[clear_prefix_cache] re-warm done: %d completed, %d skipped, %.1fs", + result["count"], + result["skipped"], + result["elapsed_ms"] / 1000, + ) + except Exception as e: + logger.warning( + "[clear_prefix_cache] re-warm failed: %s", + _sanitize_log_text(e, limit=500), + ) + + asyncio.create_task(_rewarm()) + rewarm_scheduled = True + + status = "cleared" if cleared else "not_supported" + return {"status": status, "rewarm_scheduled": rewarm_scheduled} @app.get("/v1/models", dependencies=[Depends(verify_api_key)]) @@ -819,6 +2830,14 @@ async def list_models() -> ModelsResponse: models = [] if _model_name: models.append(ModelInfo(id=_model_name)) + if _embedding_engine is not None: + models.append( + ModelInfo(id=_embedding_engine.model_name, owned_by="vllm-mlx-embedding") + ) + if _rerank_engine is not None: + models.append( + ModelInfo(id=_rerank_engine.model_name, owned_by="vllm-mlx-reranker") + ) return ModelsResponse(data=models) @@ -871,33 +2890,27 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: } ``` - Supported models: + Supported request-time models: - mlx-community/all-MiniLM-L6-v2-4bit (fast, compact) - mlx-community/embeddinggemma-300m-6bit (high quality) - mlx-community/bge-large-en-v1.5-4bit (best for English) - - Any BERT/XLM-RoBERTa/ModernBERT model from HuggingFace + - mlx-community/multilingual-e5-small-mlx + - mlx-community/multilingual-e5-large-mlx + - mlx-community/bert-base-uncased-mlx + - mlx-community/ModernBERT-base-mlx + + Other embedding models must be pinned explicitly with --embedding-model at + server startup. """ global _embedding_engine tracker = _metrics.track_inference("embeddings", stream=False) try: - # Resolve model name - model_name = request.model - - # If an embedding model was pre-configured at startup, only allow that model - if ( - _embedding_model_locked is not None - and model_name != _embedding_model_locked - ): - raise HTTPException( - status_code=400, - detail=( - f"Embedding model '{model_name}' is not available. 
" - f"This server was started with --embedding-model {_embedding_model_locked}. " - f"Only '{_embedding_model_locked}' can be used for embeddings. " - f"Restart the server with a different --embedding-model to use '{model_name}'." - ), - ) + # Resolve model name before any lazy-load path is reached. + model_name = resolve_embedding_model_name( + request.model, + locked_model=_embedding_model_locked, + ) # Lazy-load or swap embedding engine load_embedding_model(model_name, lock=False, reuse_existing=True) @@ -918,43 +2931,176 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: elapsed = time.perf_counter() - start_time logger.info( - f"Embeddings: {len(texts)} inputs, {prompt_tokens} tokens in {elapsed:.2f}s" + f"Embeddings: {len(texts)} inputs, {prompt_tokens} tokens in {elapsed:.2f}s" + ) + + # Build OpenAI-compatible response with ordered indices + data = [ + EmbeddingData(index=i, embedding=vec) for i, vec in enumerate(embeddings) + ] + + response = EmbeddingResponse( + data=data, + model=model_name, + usage=EmbeddingUsage( + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, + ), + ) + tracker.finish( + result="success", + prompt_tokens=prompt_tokens, + completion_tokens=0, + ) + return response + + except ImportError: + tracker.finish(result="error") + raise HTTPException( + status_code=503, + detail=( + "mlx-embeddings not installed. Install with: pip install mlx-embeddings" + ), + ) + except HTTPException as exc: + tracker.finish(result=_metrics_result_from_status(exc.status_code)) + raise + except Exception as e: + tracker.finish(result="error") + _log_and_raise_internal_error( + "Embedding generation failed", + e, + "Embedding generation failed", + ) + + +# ============================================================================= +# Reranking Endpoint +# ============================================================================= + + +@app.post( + "/v1/rerank", + dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], +) +async def rerank_documents(request: RerankRequest) -> RerankResponse: + """ + Rerank documents against a query using a cross-encoder model. + + Jina/Cohere-compatible reranking API. Accepts a query and a list of + documents (strings or {text: ...} objects), returns results sorted + by relevance score descending. + """ + global _rerank_engine + + try: + model_name = request.model + + # If a reranker model was pre-configured at startup, only allow that model + if _rerank_model_locked is not None and model_name != _rerank_model_locked: + raise HTTPException( + status_code=400, + detail=( + f"Reranker model '{model_name}' is not available. " + f"This server was started with --rerank-model {_rerank_model_locked}. " + f"Only '{_rerank_model_locked}' can be used for reranking. " + f"Restart the server with a different --rerank-model to use '{model_name}'." 
+ ), + ) + + # Validate query + if not request.query or not request.query.strip(): + raise HTTPException(status_code=400, detail="Query must not be empty") + + # Validate documents + if not request.documents: + raise HTTPException( + status_code=400, detail="Documents list must not be empty" + ) + + # Validate top_n + if request.top_n is not None and request.top_n > len(request.documents): + raise HTTPException( + status_code=400, + detail=( + f"top_n ({request.top_n}) must not exceed the number of " + f"documents ({len(request.documents)})" + ), + ) + + # Require --rerank-model at startup; no unconstrained lazy loading + if _rerank_engine is None: + raise HTTPException( + status_code=404, + detail=( + "No reranker model loaded. Start the server with " + "--rerank-model to enable the /v1/rerank endpoint." + ), + ) + + # Extract text from documents (handle both string and object formats) + doc_texts = [] + original_docs = [] + for doc in request.documents: + if isinstance(doc, str): + doc_texts.append(doc) + original_docs.append({"text": doc}) + elif isinstance(doc, dict) and "text" in doc: + doc_texts.append(doc["text"]) + original_docs.append(doc) + else: + raise HTTPException( + status_code=400, + detail=( + f"Each document must be a string or an object with a 'text' field. " + f"Got: {type(doc).__name__}" + ), + ) + + start_time = time.perf_counter() + + # Run scoring off the event loop with concurrency limit. + # score_pairs returns (scores, total_tokens) from the same + # tokenization pass used for scoring — no double tokenization. + import asyncio + + async with _rerank_engine._semaphore: + scores, total_tokens = await asyncio.to_thread( + _rerank_engine.score_pairs, request.query, doc_texts + ) + + elapsed = time.perf_counter() - start_time + logger.info( + f"Rerank: {len(doc_texts)} documents, {total_tokens} tokens in {elapsed:.2f}s" ) - # Build OpenAI-compatible response with ordered indices - data = [ - EmbeddingData(index=i, embedding=vec) for i, vec in enumerate(embeddings) - ] + # Build results with original index and optional document + results = [] + for i, score in enumerate(scores): + result = RerankResult( + index=i, + relevance_score=score, + document=original_docs[i] if request.return_documents else None, + ) + results.append(result) - response = EmbeddingResponse( - data=data, + # Sort by relevance score descending + results.sort(key=lambda r: r.relevance_score, reverse=True) + + # Apply top_n limit + if request.top_n is not None: + results = results[: request.top_n] + + return RerankResponse( model=model_name, - usage=EmbeddingUsage( - prompt_tokens=prompt_tokens, - total_tokens=prompt_tokens, - ), - ) - tracker.finish( - result="success", - prompt_tokens=prompt_tokens, - completion_tokens=0, + results=results, + usage=RerankUsage(total_tokens=total_tokens), ) - return response - except ImportError: - tracker.finish(result="error") - raise HTTPException( - status_code=503, - detail=( - "mlx-embeddings not installed. 
Install with: pip install mlx-embeddings" - ), - ) - except HTTPException as exc: - tracker.finish(result=_metrics_result_from_status(exc.status_code)) + except HTTPException: raise except Exception as e: - tracker.finish(result="error") - logger.error(f"Embedding generation failed: {e}") + logger.error(f"Reranking failed: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -1007,15 +3153,27 @@ async def list_mcp_servers() -> MCPServersResponse: @app.post("/v1/mcp/execute", dependencies=[Depends(verify_api_key)]) async def execute_mcp_tool(request: MCPExecuteRequest) -> MCPExecuteResponse: """Execute an MCP tool.""" + global _mcp_executor + if _mcp_manager is None: raise HTTPException( status_code=503, detail="MCP not configured. Start server with --mcp-config" ) - result = await _mcp_manager.execute_tool( - request.tool_name, - request.arguments, - ) + if _mcp_executor is None: + from vllm_mlx.mcp import ToolExecutor + + _mcp_executor = ToolExecutor(_mcp_manager) + + tool_call = { + "id": f"mcp-{uuid.uuid4().hex[:8]}", + "type": "function", + "function": { + "name": request.tool_name, + "arguments": request.arguments, + }, + } + result, _ = (await _mcp_executor.execute_tool_calls([tool_call], parallel=False))[0] return MCPExecuteResponse( tool_name=result.tool_name, @@ -1056,27 +3214,19 @@ async def create_transcription( try: from .audio.stt import STTEngine # Lazy import - optional feature - # Map model aliases to full names - model_map = { - "whisper-large-v3": "mlx-community/whisper-large-v3-mlx", - "whisper-large-v3-turbo": "mlx-community/whisper-large-v3-turbo", - "whisper-medium": "mlx-community/whisper-medium-mlx", - "whisper-small": "mlx-community/whisper-small-mlx", - "parakeet": "mlx-community/parakeet-tdt-0.6b-v2", - "parakeet-v3": "mlx-community/parakeet-tdt-0.6b-v3", - } - model_name = model_map.get(model, model) + model_name = resolve_stt_model_name(model) # Load engine if needed if _stt_engine is None or _stt_engine.model_name != model_name: _stt_engine = STTEngine(model_name) _stt_engine.load() - # Save uploaded file temporarily - with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: - content = await file.read() - tmp.write(content) - tmp_path = tmp.name + # Stream uploaded file to disk under a hard size cap. 
+ tmp_path = await save_upload_with_limit( + file, + max_bytes=_max_audio_upload_bytes, + default_suffix=".wav", + ) try: result = _stt_engine.transcribe(tmp_path, language=language) @@ -1105,8 +3255,11 @@ async def create_transcription( raise except Exception as e: tracker.finish(result="error") - logger.error(f"Transcription failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + _log_and_raise_internal_error( + "Transcription failed", + e, + "Transcription failed", + ) @app.post("/v1/audio/speech", dependencies=[Depends(verify_api_key)]) @@ -1132,16 +3285,8 @@ async def create_speech( try: from .audio.tts import TTSEngine # Lazy import - optional feature - # Map model aliases to full names - model_map = { - "kokoro": "mlx-community/Kokoro-82M-bf16", - "kokoro-4bit": "mlx-community/Kokoro-82M-4bit", - "chatterbox": "mlx-community/chatterbox-turbo-fp16", - "chatterbox-4bit": "mlx-community/chatterbox-turbo-4bit", - "vibevoice": "mlx-community/VibeVoice-Realtime-0.5B-4bit", - "voxcpm": "mlx-community/VoxCPM1.5", - } - model_name = model_map.get(model, model) + model_name = resolve_tts_model_name(model) + validate_tts_input_length(input, max_chars=_max_tts_input_chars) # Load engine if needed if _tts_engine is None or _tts_engine.model_name != model_name: @@ -1168,8 +3313,11 @@ async def create_speech( raise except Exception as e: tracker.finish(result="error") - logger.error(f"TTS generation failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) + _log_and_raise_internal_error( + "TTS generation failed", + e, + "Speech generation failed", + ) @app.get("/v1/audio/voices", dependencies=[Depends(verify_api_key)]) @@ -1190,11 +3338,36 @@ async def list_voices(model: str = "kokoro"): # ============================================================================= +async def _ensure_sse_terminal( + generator: AsyncIterator[str], + terminal_frame: str, +) -> AsyncIterator[str]: + """Guarantee that *terminal_frame* is emitted exactly once at the end of + *generator*, even if the generator raises mid-stream. + + If the inner generator already yields the terminal frame on its happy path, + the wrapper detects it and avoids double-emission. If the generator raises + before reaching the terminal, the wrapper emits it in the ``finally`` block. + """ + emitted = False + try: + async for chunk in generator: + if chunk == terminal_frame: + emitted = True + yield chunk + except Exception as e: + logger.error(f"Streaming error, ensuring terminal frame: {e}") + finally: + if not emitted: + yield terminal_frame + + async def _disconnect_guard( generator: AsyncIterator[str], raw_request: Request, poll_interval: float = 0.5, heartbeat_interval: float = 5.0, + cleanup=None, ) -> AsyncIterator[str]: """Wrap streaming generator to abort on client disconnect. 
@@ -1273,6 +3446,12 @@ async def _wait_disconnect(): f"{chunk_count} chunks, elapsed={_elapsed()}" ) break + except Exception as exc: + logger.error( + f"[disconnect_guard] generator raised {type(exc).__name__}: {exc}, " + f"after {chunk_count} chunks, elapsed={_elapsed()}" + ) + break chunk_count += 1 if chunk_count == 1: logger.info( @@ -1307,6 +3486,10 @@ async def _wait_disconnect(): # Instead, rely on the task cancellation propagation: # anext_task.cancel() → CancelledError in stream_outputs() # → finally block → abort_request() → request removed from scheduler + if cleanup is not None: + result = cleanup() + if asyncio.iscoroutine(result): + await result logger.info( f"[disconnect_guard] CLEANUP done, {chunk_count} chunks, " f"{heartbeat_count} heartbeats, elapsed={_elapsed()}" @@ -1318,6 +3501,8 @@ async def _wait_with_disconnect( raw_request: Request, timeout: float, poll_interval: float = 0.5, + timeout_detail_seconds: float | None = None, + cleanup_result=None, ): """Run a coroutine with both timeout and client disconnect detection. @@ -1363,7 +3548,10 @@ async def _wait_disconnect(): pass raise HTTPException( status_code=504, - detail=f"Request timed out after {timeout:.1f} seconds", + detail=( + "Request timed out after " + f"{(timeout_detail_seconds or timeout):.1f} seconds" + ), ) if disconnect_task in done: @@ -1372,11 +3560,22 @@ async def _wait_disconnect(): f"[disconnect_guard] CLIENT DISCONNECTED (non-stream) " f"elapsed={_time.monotonic() - _t0:.1f}s" ) - task.cancel() - try: - await task - except (asyncio.CancelledError, Exception): - pass + if task in done: + try: + result = task.result() + except (asyncio.CancelledError, Exception): + pass + else: + if cleanup_result is not None: + cleanup = cleanup_result(result) + if asyncio.iscoroutine(cleanup): + await cleanup + else: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass return None # Signal to caller that client disconnected # Task completed @@ -1389,6 +3588,50 @@ async def _wait_disconnect(): task.cancel() +def _start_request_budget(timeout: float | None) -> tuple[float, float]: + """Return the total timeout and absolute deadline for a request.""" + total_timeout = timeout or _default_timeout + return total_timeout, time.monotonic() + total_timeout + + +def _remaining_request_timeout(total_timeout: float, deadline: float) -> float: + """Compute remaining request budget or raise the standard timeout error.""" + remaining = deadline - time.monotonic() + if remaining <= 0: + raise HTTPException( + status_code=504, + detail=f"Request timed out after {total_timeout:.1f} seconds", + ) + return remaining + + +async def _acquire_default_engine_for_request( + raw_request: Request, + *, + total_timeout: float, + deadline: float, + count_activity: bool = True, +) -> BaseEngine | None: + """Acquire the default engine inside the request guardrails.""" + if count_activity: + acquire_coro = _acquire_default_engine() + cleanup = lambda _result: _release_default_engine() + else: + acquire_coro = _acquire_default_engine(count_activity=False) + cleanup = lambda _result: _release_default_engine(count_activity=False) + + if raw_request is None: + return await acquire_coro + + return await _wait_with_disconnect( + acquire_coro, + raw_request, + timeout=_remaining_request_timeout(total_timeout, deadline), + timeout_detail_seconds=total_timeout, + cleanup_result=cleanup, + ) + + # ============================================================================= # Completion Endpoints # 
============================================================================= @@ -1400,11 +3643,12 @@ async def _wait_disconnect(): async def create_completion(request: CompletionRequest, raw_request: Request): """Create a text completion.""" _validate_model_name(request.model) - engine = get_engine() + effective_max_tokens = _resolve_request_max_tokens(request.max_tokens) tracker = _metrics.track_inference("completions", stream=request.stream) # Handle single prompt or list of prompts prompts = request.prompt if isinstance(request.prompt, list) else [request.prompt] + total_timeout, deadline = _start_request_budget(request.timeout) # --- Detailed request logging --- prompt_preview = prompts[0][:200] if prompts else "(empty)" @@ -1415,101 +3659,123 @@ async def create_completion(request: CompletionRequest, raw_request: Request): f"top_p={request.top_p} top_k={request.top_k} min_p={request.min_p} " f"presence_penalty={request.presence_penalty} " f"repetition_penalty={request.repetition_penalty} " - f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}" + f"prompt_chars={prompt_len} " + f"prompt_preview={_sanitize_log_text(prompt_preview, limit=200)}" ) # Resolve repetition penalty for completions comp_rep_penalty = request.repetition_penalty - if request.stream: - return StreamingResponse( - _disconnect_guard( - stream_completion( - engine, - prompts[0], - request, - repetition_penalty=comp_rep_penalty, - metrics_tracker=tracker, + engine = await _acquire_default_engine_for_request( + raw_request, + total_timeout=total_timeout, + deadline=deadline, + ) + if engine is None: + return Response(status_code=499) + release_on_exit = True + + try: + if request.stream: + response = StreamingResponse( + _disconnect_guard( + _ensure_sse_terminal( + stream_completion( + engine, + prompts[0], + request, + effective_max_tokens, + repetition_penalty=comp_rep_penalty, + metrics_tracker=tracker, + ), + "data: [DONE]\n\n", + ), + raw_request, + cleanup=_release_default_engine, ), - raw_request, - ), - media_type="text/event-stream", - ) + media_type="text/event-stream", + ) + release_on_exit = False + return response - # Non-streaming response with timing and timeout - start_time = time.perf_counter() - timeout = request.timeout or _default_timeout - choices = [] - total_completion_tokens = 0 - total_prompt_tokens = 0 - - for i, prompt in enumerate(prompts): - generate_kwargs = { - "prompt": prompt, - "max_tokens": request.max_tokens or _default_max_tokens, - "temperature": _resolve_temperature(request.temperature), - "top_p": _resolve_top_p(request.top_p), - "top_k": request.top_k or 0, - "min_p": request.min_p or 0.0, - "presence_penalty": request.presence_penalty or 0.0, - "stop": request.stop, - } - if comp_rep_penalty is not None: - generate_kwargs["repetition_penalty"] = comp_rep_penalty - if request.specprefill is not None: - generate_kwargs["specprefill"] = request.specprefill - if request.specprefill_keep_pct is not None: - generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + # Non-streaming response with timing and timeout + start_time = time.perf_counter() + choices = [] + total_completion_tokens = 0 + total_prompt_tokens = 0 + for i, prompt in enumerate(prompts): + generate_kwargs = { + "prompt": prompt, + "max_tokens": effective_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "stop": 
request.stop, + } + if comp_rep_penalty is not None: + generate_kwargs["repetition_penalty"] = comp_rep_penalty + if request.specprefill is not None: + generate_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + try: + if raw_request is None: + output = await engine.generate(**generate_kwargs) + else: + output = await _wait_with_disconnect( + engine.generate(**generate_kwargs), + raw_request, + timeout=_remaining_request_timeout(total_timeout, deadline), + timeout_detail_seconds=total_timeout, + ) + except HTTPException as exc: + tracker.finish(result=_metrics_result_from_status(exc.status_code)) + raise + if output is None: + tracker.finish( + result="client_closed", + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + ) + return Response(status_code=499) # Client closed request - try: - output = await _wait_with_disconnect( - engine.generate(**generate_kwargs), - raw_request, - timeout=timeout, + choices.append( + CompletionChoice( + index=i, + text=output.text, + finish_reason=output.finish_reason, + ) ) - except HTTPException as exc: - tracker.finish(result=_metrics_result_from_status(exc.status_code)) - raise - if output is None: - tracker.finish( - result="client_closed", - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, + total_completion_tokens += output.completion_tokens + total_prompt_tokens += ( + output.prompt_tokens if hasattr(output, "prompt_tokens") else 0 ) - return Response(status_code=499) # Client closed request - choices.append( - CompletionChoice( - index=i, - text=output.text, - finish_reason=output.finish_reason, - ) - ) - total_completion_tokens += output.completion_tokens - total_prompt_tokens += ( - output.prompt_tokens if hasattr(output, "prompt_tokens") else 0 + elapsed = time.perf_counter() - start_time + tokens_per_sec = total_completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Completion: {total_prompt_tokens} prompt + {total_completion_tokens} completion tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) - elapsed = time.perf_counter() - start_time - tokens_per_sec = total_completion_tokens / elapsed if elapsed > 0 else 0 - logger.info( - f"Completion: {total_prompt_tokens} prompt + {total_completion_tokens} completion tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" - ) - - tracker.finish( - result="success", - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, - ) - return CompletionResponse( - model=_model_name, - choices=choices, - usage=Usage( + tracker.finish( + result="success", prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens, - ), - ) + ) + return CompletionResponse( + model=_model_name, + choices=choices, + usage=Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + ), + ) + finally: + if release_on_exit: + await _release_default_engine() @app.post( @@ -1528,267 +3794,185 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re {"type": "text", "text": "What's in this image?"}, {"type": "image_url", "image_url": {"url": "https://..."}} ] - }] - ``` - - Video support: - ```json - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": "What happens in this video?"}, - {"type": "video_url", 
"video_url": {"url": "https://example.com/video.mp4"}} - ] - }] - ``` - - Structured output (JSON mode): - ```json - response_format={"type": "json_object"} - ``` - - Structured output (JSON Schema): - ```json - response_format={ - "type": "json_schema", - "json_schema": { - "name": "my_schema", - "schema": {"type": "object", "properties": {...}} - } - } - ``` - """ - _validate_model_name(request.model) - engine = get_engine() - tracker = _metrics.track_inference("chat_completions", stream=request.stream) - - # --- Detailed request logging --- - n_msgs = len(request.messages) - msg_roles = [m.role for m in request.messages] - total_chars = 0 - last_user_preview = "" - for m in request.messages: - content = m.content if isinstance(m.content, str) else str(m.content) - total_chars += len(content) - if m.role == "user": - last_user_preview = content[:300] - has_tools = bool(request.tools) - n_tools = len(request.tools) if request.tools else 0 - logger.info( - f"[REQUEST] POST /v1/chat/completions stream={request.stream} " - f"model={request.model!r} max_tokens={request.max_tokens} " - f"temp={request.temperature} top_p={request.top_p} " - f"top_k={request.top_k} min_p={request.min_p} " - f"presence_penalty={request.presence_penalty} " - f"repetition_penalty={request.repetition_penalty} " - f"msgs={n_msgs} roles={msg_roles} " - f"total_chars={total_chars} tools={n_tools} " - f"response_format={request.response_format}" - ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") - - # For MLLM models, keep original messages with embedded images - # (MLLM.chat() extracts images from message content internally) - if engine.is_mllm: - # Convert Pydantic messages to dicts, excluding None fields - # to prevent chat templates from misinterpreting key presence - # (e.g. image_url: null on text parts triggers Qwen3-VL crash) - messages = [] - for msg in request.messages: - if hasattr(msg, "model_dump"): - msg_dict = msg.model_dump(exclude_none=True) - else: - raw = dict(msg) - msg_dict = {k: v for k, v in raw.items() if v is not None} - messages.append(msg_dict) - images, videos = [], [] # MLLM extracts these from messages - logger.debug(f"MLLM: Processing {len(messages)} messages") - # Convert tool_call arguments from JSON string to dict so that - # chat templates can iterate them (e.g. GLM-4.6V calls .items()). - # The LLM path does this inside extract_multimodal_content(), but - # the MLLM path bypasses that function. - if engine.preserve_native_tool_format: - for msg_dict in messages: - for tc in msg_dict.get("tool_calls") or []: - func = tc.get("function") or {} - args = func.get("arguments") - if isinstance(args, str): - try: - func["arguments"] = json.loads(args) - except (json.JSONDecodeError, ValueError): - pass - messages = _normalize_messages(messages) - else: - # For LLM, extract text, images, and videos separately - messages, images, videos = extract_multimodal_content( - request.messages, - preserve_native_format=engine.preserve_native_tool_format, - ) - messages = _normalize_messages(messages) - - has_media = bool(images or videos) - if engine.is_mllm and not has_media: - # MLLM extracts media from messages directly, so images/videos are - # always empty. Check message content for video/image types instead. 
- for msg in request.messages: - content = msg.content if hasattr(msg, "content") else msg.get("content", "") - if isinstance(content, list): - for item in content: - item_type = ( - item.type - if hasattr(item, "type") - else (item.get("type", "") if isinstance(item, dict) else "") - ) - if item_type in ("image_url", "image", "video", "video_url"): - has_media = True - break - if has_media: - break - - # Handle response_format - inject system prompt if needed - response_format = request.response_format - if response_format: - json_instruction = build_json_system_prompt(response_format) - if json_instruction: - # Inject JSON instruction into messages - messages = _inject_json_instruction(messages, json_instruction) - - # Resolve repetition penalty - rep_penalty = request.repetition_penalty - - # Prepare kwargs - chat_kwargs = { - "max_tokens": request.max_tokens or _default_max_tokens, - "temperature": _resolve_temperature(request.temperature), - "top_p": _resolve_top_p(request.top_p), - "top_k": request.top_k or 0, - "min_p": request.min_p or 0.0, - "presence_penalty": request.presence_penalty or 0.0, - "repetition_penalty": request.repetition_penalty or 1.0, - } - if rep_penalty is not None: - chat_kwargs["repetition_penalty"] = rep_penalty - - # Add multimodal content - if has_media: - chat_kwargs["images"] = images if images else None - chat_kwargs["videos"] = videos if videos else None - if request.video_fps: - chat_kwargs["video_fps"] = request.video_fps - if request.video_max_frames: - chat_kwargs["video_max_frames"] = request.video_max_frames + }] + ``` - # SpecPrefill: per-request overrides - if request.specprefill is not None: - chat_kwargs["specprefill"] = request.specprefill - if request.specprefill_keep_pct is not None: - chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + Video support: + ```json + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What happens in this video?"}, + {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + ] + }] + ``` - # Enable/disable thinking mode per request - if request.enable_thinking is not None: - chat_kwargs["enable_thinking"] = request.enable_thinking + Structured output (JSON mode): + ```json + response_format={"type": "json_object"} + ``` - # Add tools if provided - if request.tools and request.tool_choice != "none": - chat_kwargs["tools"] = convert_tools_for_template(request.tools) + Structured output (JSON Schema): + ```json + response_format={ + "type": "json_schema", + "json_schema": { + "name": "my_schema", + "schema": {"type": "object", "properties": {...}} + } + } + ``` + """ + _validate_model_name(request.model) + effective_max_tokens = _resolve_request_max_tokens(request.max_tokens) + tracker = _metrics.track_inference("chat_completions", stream=request.stream) + total_timeout, deadline = _start_request_budget(request.timeout) - if request.stream: - return StreamingResponse( - _disconnect_guard( - stream_chat_completion( - engine, - messages, - request, - metrics_tracker=tracker, - **chat_kwargs, - ), - raw_request, - ), - media_type="text/event-stream", - ) + # --- Detailed request logging --- + n_msgs = len(request.messages) + msg_roles = [m.role for m in request.messages] + total_chars = 0 + last_user_preview = "" + for m in request.messages: + content = m.content if isinstance(m.content, str) else str(m.content) + total_chars += len(content) + if m.role == "user": + last_user_preview = content[:300] + n_tools = len(request.tools) if request.tools else 0 + 
logger.info( + f"[REQUEST] POST /v1/chat/completions stream={request.stream} " + f"model={request.model!r} max_tokens={request.max_tokens} " + f"temp={request.temperature} top_p={request.top_p} " + f"top_k={request.top_k} min_p={request.min_p} " + f"presence_penalty={request.presence_penalty} " + f"repetition_penalty={request.repetition_penalty} " + f"msgs={n_msgs} roles={msg_roles} " + f"total_chars={total_chars} tools={n_tools} " + f"response_format={request.response_format}" + ) + logger.info( + "[REQUEST] last user message preview: %s", + _sanitize_log_text(last_user_preview, limit=300), + ) - # Non-streaming response with timing and timeout - start_time = time.perf_counter() - timeout = request.timeout or _default_timeout + engine = await _acquire_default_engine_for_request( + raw_request, + total_timeout=total_timeout, + deadline=deadline, + ) + if engine is None: + return Response(status_code=499) + release_on_exit = True try: - output = await _wait_with_disconnect( - engine.chat(messages=messages, **chat_kwargs), - raw_request, - timeout=timeout, + prepared = _prepare_chat_completion_invocation( + engine, + request, + effective_max_tokens, ) - except HTTPException as exc: - tracker.finish(result=_metrics_result_from_status(exc.status_code)) - raise - if output is None: - tracker.finish(result="client_closed") - return Response(status_code=499) # Client closed request - elapsed = time.perf_counter() - start_time - tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 - logger.info( - f"Chat completion: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" - ) + if request.stream: + response = StreamingResponse( + _disconnect_guard( + _ensure_sse_terminal( + stream_chat_completion( + engine, + prepared.messages, + request, + metrics_tracker=tracker, + **prepared.chat_kwargs, + ), + "data: [DONE]\n\n", + ), + raw_request, + cleanup=_release_default_engine, + ), + media_type="text/event-stream", + ) + release_on_exit = False + return response - # Parse tool calls from output using configured parser - # Skip tool parsing when request has no tools — otherwise the parser - # can misinterpret JSON output (e.g. response_format) as tool calls. - if request.tools: - cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, request) - else: - cleaned_text, tool_calls = output.text, None + start_time = time.perf_counter() - # Extract reasoning content (strips channel tokens before JSON extraction) - # Skip reasoning parser when enable_thinking=False (no think tags expected) - reasoning_text = None - if _reasoning_parser and request.enable_thinking is not False: - # Always use original output.text for reasoning extraction so - # content is preserved even when tool calls are present. 
- text_to_parse = output.text - reasoning_text, remaining_text = _reasoning_parser.extract_reasoning( - text_to_parse + try: + output = await _wait_with_disconnect( + engine.chat(messages=prepared.messages, **prepared.chat_kwargs), + raw_request, + timeout=_remaining_request_timeout(total_timeout, deadline), + timeout_detail_seconds=total_timeout, + ) + except HTTPException as exc: + tracker.finish(result=_metrics_result_from_status(exc.status_code)) + raise + if output is None: + tracker.finish(result="client_closed") + return Response(status_code=499) # Client closed request + + elapsed = time.perf_counter() - start_time + tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Chat completion: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) - # Only update cleaned_text from reasoning parser when no tool calls - # (tool parser already set cleaned_text appropriately) - if not tool_calls: - cleaned_text = remaining_text - # Process response_format if specified (after reasoning parser cleaned the text) - if response_format and not tool_calls: - json_input = cleaned_text or output.text - _, parsed_json, is_valid, error = parse_json_output(json_input, response_format) - if parsed_json is not None: - # Return JSON as string - cleaned_text = json.dumps(parsed_json) - if not is_valid: - logger.warning(f"JSON validation failed: {error}") - - # Determine finish reason - finish_reason = "tool_calls" if tool_calls else output.finish_reason - - tracker.finish( - result="success", - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - ) - return ChatCompletionResponse( - model=_model_name, - choices=[ - ChatCompletionChoice( - message=AssistantMessage( - content=clean_output_text(cleaned_text) if cleaned_text else None, - reasoning=reasoning_text, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, + reasoning_text, cleaned_text, tool_calls = _extract_reasoning_and_tool_calls( + output.text, + request, + allow_reasoning=( + getattr(request, "enable_thinking", None) is not False + and prepared.json_logits_processor is None + ), + engine=engine, + ) + + # Process response_format if specified (after reasoning parser cleaned the text) + if prepared.response_format and not tool_calls: + json_input = cleaned_text or output.text + _, parsed_json, is_valid, error = parse_json_output( + json_input, prepared.response_format ) - ], - usage=Usage( + if parsed_json is not None: + # Return JSON as string + cleaned_text = json.dumps(parsed_json) + if not is_valid: + if prepared.json_logits_processor is not None: + logger.error( + "Constrained decoding produced invalid JSON: %s", error + ) + else: + logger.warning(f"JSON validation failed: {error}") + + # Determine finish reason + finish_reason = "tool_calls" if tool_calls else output.finish_reason + + tracker.finish( + result="success", prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, - total_tokens=output.prompt_tokens + output.completion_tokens, - ), - ) + ) + return ChatCompletionResponse( + model=_model_name, + choices=[ + ChatCompletionChoice( + message=AssistantMessage( + content=( + clean_output_text(cleaned_text) if cleaned_text else None + ), + reasoning=reasoning_text, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, + ) + ], + usage=Usage( + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + total_tokens=output.prompt_tokens + output.completion_tokens, + ), + ) + 
finally: + if release_on_exit: + await _release_default_engine() def _normalize_messages(messages: list[dict]) -> list[dict]: @@ -1849,6 +4033,115 @@ def _normalize_messages(messages: list[dict]) -> list[dict]: return merged +def _get_engine_tokenizer(engine) -> object | None: + """ + Return the tokenizer backing ``engine``, if exposed. + + Different engine classes store the tokenizer under different attributes. + We try the common ones and return ``None`` if nothing matches, so that + optional features like constrained decoding can degrade gracefully. + """ + for attr in ("_tokenizer", "tokenizer", "_processor", "processor"): + tok = getattr(engine, attr, None) + if tok is not None: + return tok + return None + + +@app.post( + "/v1/responses", + dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], +) +async def create_response(request: ResponsesRequest, raw_request: Request): + """Create a Responses API response.""" + if request.stream: + return StreamingResponse( + _disconnect_guard(_stream_responses_request(request), raw_request), + media_type="text/event-stream", + ) + + response_object, _persisted_messages = await _run_responses_request( + request, raw_request + ) + if response_object is None: + return Response(status_code=499) + + return response_object + + +def _get_forced_tool_name(tool_choice) -> str | None: + """Extract forced tool name from tool_choice, if any. + + Returns the function name when tool_choice is a dict like + {"type": "function", "function": {"name": "X"}}, or None otherwise. + """ + if not isinstance(tool_choice, dict): + return None + if tool_choice.get("type") != "function": + return None + func = tool_choice.get("function") + if isinstance(func, dict): + return func.get("name") + return None + + +def _apply_forced_tool_choice(tool_choice, tools, messages, chat_kwargs=None): + """Apply forced tool_choice by filtering tools and injecting instructions. + + Handles: + - tool_choice={"type":"function","function":{"name":"X"}} -> filter + instruct + - tool_choice="required" -> instruct model to call at least one tool + + Args: + tool_choice: The tool_choice value from the request + tools: List of converted tools for the template + messages: The message list (will be copied if modified) + chat_kwargs: Optional dict to modify (e.g. disable thinking) + + Returns: + Tuple of (tools, messages) - potentially filtered/modified + """ + if not tools: + return tools, messages + + forced_name = _get_forced_tool_name(tool_choice) + if forced_name: + # Filter tools to only the forced function + filtered = [t for t in tools if _tool_name(t) == forced_name] + if not filtered: + available = [_tool_name(t) for t in tools if _tool_name(t)] + raise ValueError( + f"tool_choice function '{forced_name}' not found in tools. " + f"Available: {available}" + ) + tools = filtered + instruction = ( + f"[IMPORTANT INSTRUCTION] You MUST call the `{forced_name}` function. " + f"Do NOT respond with plain text. Respond ONLY with a tool call to " + f"`{forced_name}`. This is mandatory." + ) + messages = _inject_json_instruction(messages, instruction) + # Disable thinking to prevent model from reasoning its way out + if chat_kwargs is not None: + chat_kwargs["enable_thinking"] = False + elif tool_choice == "required": + instruction = ( + "[IMPORTANT INSTRUCTION] You MUST call at least one of the available " + "tools. Do NOT respond with plain text only." 
+ ) + messages = _inject_json_instruction(messages, instruction) + + return tools, messages + + +def _tool_name(tool: dict) -> str | None: + """Extract function name from a tool definition dict.""" + func = tool.get("function") + if isinstance(func, dict): + return func.get("name") + return None + + def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. @@ -1897,7 +4190,9 @@ def _convert_anthropic_stop_reason(openai_reason: str | None) -> str: return mapping.get(openai_reason or "", "end_turn") -@app.post("/v1/messages") +@app.post( + "/v1/messages", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)] +) async def create_anthropic_message( request: Request, ): @@ -1909,7 +4204,6 @@ async def create_anthropic_message( Supports both streaming and non-streaming modes. """ - engine = get_engine() tracker = _metrics.track_inference("anthropic_messages", stream=False) # Parse the raw body to handle Anthropic request format. @@ -1928,6 +4222,7 @@ async def create_anthropic_message( anthropic_request = AnthropicRequest(**body) _validate_model_name(anthropic_request.model) + effective_max_tokens = _resolve_request_max_tokens(anthropic_request.max_tokens) # --- Detailed request logging --- n_msgs = len(anthropic_request.messages) @@ -1946,148 +4241,173 @@ async def create_anthropic_message( f"msgs={n_msgs} total_chars={total_chars} system_chars={sys_chars} " f"tools={n_tools}" ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + logger.info( + "[REQUEST] last user message preview: %s", + _sanitize_log_text(last_user_preview, limit=300), + ) # Convert Anthropic request -> OpenAI request openai_request = anthropic_to_openai(anthropic_request) - - if anthropic_request.stream: - return StreamingResponse( - _disconnect_guard( - _stream_anthropic_messages( - engine, - openai_request, - anthropic_request, - metrics_tracker=tracker, - ), - request, - ), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, - ) - - # Non-streaming: run inference through existing engine - messages, images, videos = extract_multimodal_content( - openai_request.messages, - preserve_native_format=engine.preserve_native_tool_format, + total_timeout, deadline = _start_request_budget(None) + engine = await _acquire_default_engine_for_request( + request, + total_timeout=total_timeout, + deadline=deadline, + ) + if engine is None: + return Response(status_code=499) + release_on_exit = True + prepared = _prepare_anthropic_invocation( + engine, + openai_request, + effective_max_tokens, ) - messages = _normalize_messages(messages) - - chat_kwargs = { - "max_tokens": openai_request.max_tokens or _default_max_tokens, - "temperature": openai_request.temperature, - "top_p": openai_request.top_p, - "top_k": openai_request.top_k or 0, - "min_p": openai_request.min_p or 0.0, - "presence_penalty": openai_request.presence_penalty or 0.0, - "repetition_penalty": openai_request.repetition_penalty or 1.0, - } - - if openai_request.tools and openai_request.tool_choice != "none": - chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) - - start_time = time.perf_counter() - timeout = _default_timeout try: - output = await _wait_with_disconnect( - engine.chat(messages=messages, **chat_kwargs), - request, - timeout=timeout, - ) - except HTTPException as exc: - tracker.finish(result=_metrics_result_from_status(exc.status_code)) - raise - if output is None: - 
tracker.finish(result="client_closed") - return Response(status_code=499) # Client closed request + if anthropic_request.stream: + anthropic_terminal = ( + f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n" + ) + response = StreamingResponse( + _disconnect_guard( + _ensure_sse_terminal( + _stream_anthropic_messages( + engine, + openai_request, + anthropic_request, + prepared, + metrics_tracker=tracker, + ), + anthropic_terminal, + ), + request, + cleanup=_release_default_engine, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + release_on_exit = False + return response - elapsed = time.perf_counter() - start_time - tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 - logger.info( - f"Anthropic messages: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" - ) + start_time = time.perf_counter() + try: + output = await _wait_with_disconnect( + engine.chat(messages=prepared.messages, **prepared.chat_kwargs), + request, + timeout=_remaining_request_timeout(total_timeout, deadline), + timeout_detail_seconds=total_timeout, + ) + except HTTPException as exc: + tracker.finish(result=_metrics_result_from_status(exc.status_code)) + raise + if output is None: + tracker.finish(result="client_closed") + return Response(status_code=499) # Client closed request - # Parse tool calls (skip when no tools to avoid misinterpreting output) - if openai_request.tools: - cleaned_text, tool_calls = _parse_tool_calls_with_parser( - output.text, openai_request + elapsed = time.perf_counter() - start_time + tokens_per_sec = output.completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Anthropic messages: {output.completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) - else: - cleaned_text, tool_calls = output.text, None - # Extract reasoning if parser is configured - reasoning_text = None - if _reasoning_parser and not tool_calls: - text_to_parse = cleaned_text or output.text - reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning( - text_to_parse + reasoning_text, cleaned_text, tool_calls = _extract_reasoning_and_tool_calls( + output.text, + openai_request, + allow_reasoning=prepared.json_logits_processor is None, + engine=engine, ) - # Clean output text - final_content = None - if cleaned_text: - final_content = clean_output_text(cleaned_text) + if prepared.response_format and not tool_calls: + json_input = cleaned_text or output.text + _, parsed_json, is_valid, error = parse_json_output( + json_input, prepared.response_format + ) + if parsed_json is not None: + cleaned_text = json.dumps(parsed_json) + if not is_valid: + if prepared.json_logits_processor is not None: + logger.error( + "Constrained decoding produced invalid JSON on Anthropic endpoint: %s", + error, + ) + else: + logger.warning( + "JSON validation failed on Anthropic endpoint: %s", error + ) + + # Clean output text + final_content = None + if cleaned_text: + final_content = clean_output_text(cleaned_text) - # Build Anthropic content blocks directly (with thinking support) - content_blocks = [] + # Determine finish reason + finish_reason = "tool_calls" if tool_calls else output.finish_reason - if reasoning_text: - content_blocks.append( - AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text) - ) + # Build Anthropic content blocks directly (with thinking support) + content_blocks = [] - if final_content: - content_blocks.append( - 
AnthropicResponseContentBlock(type="text", text=final_content) - ) + if reasoning_text: + content_blocks.append( + AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text) + ) - if tool_calls: - for tc in tool_calls: - try: - tool_input = json.loads(tc.function.arguments) - except (json.JSONDecodeError, AttributeError): - tool_input = {} + if final_content: content_blocks.append( - AnthropicResponseContentBlock( - type="tool_use", - id=tc.id, - name=tc.function.name, - input=tool_input, - ) + AnthropicResponseContentBlock(type="text", text=final_content) ) - if not content_blocks: - content_blocks.append(AnthropicResponseContentBlock(type="text", text="")) + if tool_calls: + for tc in tool_calls: + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + tool_input = {} + content_blocks.append( + AnthropicResponseContentBlock( + type="tool_use", + id=tc.id, + name=tc.function.name, + input=tool_input, + ) + ) - stop_reason = _convert_anthropic_stop_reason( - "tool_calls" if tool_calls else output.finish_reason - ) + if not content_blocks: + content_blocks.append(AnthropicResponseContentBlock(type="text", text="")) - anthropic_response = AnthropicResponse( - model=_model_name, - content=content_blocks, - stop_reason=stop_reason, - usage=AnthropicUsage( - input_tokens=output.prompt_tokens, - output_tokens=output.completion_tokens, - ), - ) - tracker.finish( - result="success", - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - ) - return Response( - content=anthropic_response.model_dump_json(exclude_none=True), - media_type="application/json", - ) + stop_reason = _convert_anthropic_stop_reason( + "tool_calls" if tool_calls else output.finish_reason + ) + + anthropic_response = AnthropicResponse( + model=_model_name, + content=content_blocks, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=output.prompt_tokens, + output_tokens=output.completion_tokens, + ), + ) + tracker.finish( + result="success", + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + ) + return Response( + content=anthropic_response.model_dump_json(exclude_none=True), + media_type="application/json", + ) + finally: + if release_on_exit: + await _release_default_engine() -@app.post("/v1/messages/count_tokens") +@app.post( + "/v1/messages/count_tokens", + dependencies=[Depends(verify_api_key), Depends(check_rate_limit)], +) async def count_anthropic_tokens(request: Request): """ Count tokens for an Anthropic Messages API request. @@ -2098,63 +4418,77 @@ async def count_anthropic_tokens(request: Request): from Claude Code don't include max_tokens. 
""" body = await request.json() + request_model = body.get("model") + if isinstance(request_model, str) and request_model: + _validate_model_name(request_model) + total_timeout, deadline = _start_request_budget(None) + engine = await _acquire_default_engine_for_request( + request, + total_timeout=total_timeout, + deadline=deadline, + count_activity=False, + ) + if engine is None: + return Response(status_code=499) - engine = get_engine() tokenizer = engine.tokenizer total_tokens = 0 - # System message - system = body.get("system", "") - if isinstance(system, str) and system: - total_tokens += len(tokenizer.encode(system)) - elif isinstance(system, list): - for block in system: - if isinstance(block, dict): - text = block.get("text", "") - if text: - total_tokens += len(tokenizer.encode(text)) - - # Messages - for msg in body.get("messages", []): - content = msg.get("content", "") - if isinstance(content, str): - if content: - total_tokens += len(tokenizer.encode(content)) - elif isinstance(content, list): - for block in content: + try: + # System message + system = body.get("system", "") + if isinstance(system, str) and system: + total_tokens += len(tokenizer.encode(system)) + elif isinstance(system, list): + for block in system: if isinstance(block, dict): text = block.get("text", "") if text: total_tokens += len(tokenizer.encode(text)) - # tool_use input - if block.get("input"): - total_tokens += len( - tokenizer.encode(json.dumps(block["input"])) - ) - # tool_result content - sub_content = block.get("content", "") - if isinstance(sub_content, str) and sub_content: - total_tokens += len(tokenizer.encode(sub_content)) - elif isinstance(sub_content, list): - for item in sub_content: - if isinstance(item, dict): - item_text = item.get("text", "") - if item_text: - total_tokens += len(tokenizer.encode(item_text)) - - # Tools - for tool in body.get("tools", []): - name = tool.get("name", "") - if name: - total_tokens += len(tokenizer.encode(name)) - desc = tool.get("description", "") - if desc: - total_tokens += len(tokenizer.encode(desc)) - if tool.get("input_schema"): - total_tokens += len(tokenizer.encode(json.dumps(tool["input_schema"]))) - - return {"input_tokens": total_tokens} + + # Messages + for msg in body.get("messages", []): + content = msg.get("content", "") + if isinstance(content, str): + if content: + total_tokens += len(tokenizer.encode(content)) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + text = block.get("text", "") + if text: + total_tokens += len(tokenizer.encode(text)) + # tool_use input + if block.get("input"): + total_tokens += len( + tokenizer.encode(json.dumps(block["input"])) + ) + # tool_result content + sub_content = block.get("content", "") + if isinstance(sub_content, str) and sub_content: + total_tokens += len(tokenizer.encode(sub_content)) + elif isinstance(sub_content, list): + for item in sub_content: + if isinstance(item, dict): + item_text = item.get("text", "") + if item_text: + total_tokens += len(tokenizer.encode(item_text)) + + # Tools + for tool in body.get("tools", []): + name = tool.get("name", "") + if name: + total_tokens += len(tokenizer.encode(name)) + desc = tool.get("description", "") + if desc: + total_tokens += len(tokenizer.encode(desc)) + if tool.get("input_schema"): + total_tokens += len(tokenizer.encode(json.dumps(tool["input_schema"]))) + + return {"input_tokens": total_tokens} + finally: + await _release_default_engine(count_activity=False) def _emit_content_pieces( @@ -2214,6 +4548,7 @@ 
async def _stream_anthropic_messages( engine: BaseEngine, openai_request: ChatCompletionRequest, anthropic_request: AnthropicRequest, + prepared: PreparedChatInvocation, metrics_tracker=None, ) -> AsyncIterator[str]: """ @@ -2232,25 +4567,8 @@ async def _stream_anthropic_messages( result_label = "success" prompt_tokens = 0 - # Extract messages for engine - messages, images, videos = extract_multimodal_content( - openai_request.messages, - preserve_native_format=engine.preserve_native_tool_format, - ) - messages = _normalize_messages(messages) - - chat_kwargs = { - "max_tokens": openai_request.max_tokens or _default_max_tokens, - "temperature": openai_request.temperature, - "top_p": openai_request.top_p, - "top_k": openai_request.top_k or 0, - "min_p": openai_request.min_p or 0.0, - "presence_penalty": openai_request.presence_penalty or 0.0, - "repetition_penalty": openai_request.repetition_penalty or 1.0, - } - - if openai_request.tools and openai_request.tool_choice != "none": - chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) + messages = prepared.messages + chat_kwargs = dict(prepared.chat_kwargs) # Emit message_start message_start = { @@ -2271,7 +4589,9 @@ async def _stream_anthropic_messages( } yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" - use_reasoning = _reasoning_parser is not None + use_reasoning = _reasoning_parser is not None and not chat_kwargs.get( + "logits_processors" + ) if use_reasoning: _reasoning_parser.reset_state() @@ -2294,24 +4614,10 @@ async def _stream_anthropic_messages( # Tool call streaming suppression — prevents raw tool markup from leaking # as text_delta events. Mirrors the OpenAI streaming path logic. - global _tool_parser_instance tool_parser = None tool_accumulated_text = "" tool_markup_possible = False - tool_choice = getattr(openai_request, "tool_choice", None) - if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - except Exception: - pass - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() + tool_parser = _get_streaming_tool_parser(openai_request, engine) try: async for output in engine.stream_chat(messages=messages, **chat_kwargs): @@ -2341,7 +4647,12 @@ async def _stream_anthropic_messages( # Filter tool call markup during streaming if tool_parser and content_to_emit: - if not tool_markup_possible and "<" not in content_to_emit: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content_to_emit + ) + ): tool_accumulated_text += content_to_emit else: if not tool_markup_possible: @@ -2386,7 +4697,12 @@ async def _stream_anthropic_messages( # Filter tool call markup during streaming if tool_parser and content_to_emit: - if not tool_markup_possible and "<" not in content_to_emit: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content_to_emit + ) + ): tool_accumulated_text += content_to_emit else: if not tool_markup_possible: @@ -2428,7 +4744,11 @@ async def _stream_anthropic_messages( text_block_started = True # Check for tool calls in accumulated text - _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) + _, tool_calls = 
_parse_tool_calls_with_parser( + accumulated_text, + openai_request, + engine=engine, + ) # Close text block if text_block_started: @@ -2495,6 +4815,7 @@ async def stream_completion( engine: BaseEngine, prompt: str, request: CompletionRequest, + max_tokens: int, repetition_penalty: float | None = None, metrics_tracker=None, ) -> AsyncIterator[str]: @@ -2504,7 +4825,7 @@ async def stream_completion( completion_tokens = 0 generate_kwargs = { "prompt": prompt, - "max_tokens": request.max_tokens or _default_max_tokens, + "max_tokens": max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), "top_k": request.top_k or 0, @@ -2551,8 +4872,6 @@ async def stream_completion( if output.finished: data["usage"] = get_usage(output).model_dump() yield f"data: {json.dumps(data)}\n\n" - - yield "data: [DONE]\n\n" except HTTPException as exc: result = _metrics_result_from_status(exc.status_code) raise @@ -2563,6 +4882,7 @@ async def stream_completion( result = "error" raise finally: + yield "data: [DONE]\n\n" if metrics_tracker is not None: metrics_tracker.finish( result=result, @@ -2617,28 +4937,27 @@ async def stream_chat_completion( completion_tokens = 0 last_output = None + # Response-format streaming filter — strip markdown code fences from + # content when client asked for JSON. Non-streaming path strips fences + # via ``parse_json_output``; without this, streaming clients see + # ``"```json{...}```"`` instead of ``"{...}"`` for models that wrap + # their structured output in markdown (e.g. Gemma 4). + fence_stripper: StreamingJsonFenceStripper | None = None + _rf = getattr(request, "response_format", None) + _rf_type = None + if _rf is not None: + _rf_type = getattr(_rf, "type", None) + if _rf_type is None and isinstance(_rf, dict): + _rf_type = _rf.get("type") + if _rf_type in ("json_object", "json_schema"): + fence_stripper = StreamingJsonFenceStripper() + # Tool call streaming state - global _tool_parser_instance tool_parser = None tool_accumulated_text = "" tool_calls_detected = False - tool_markup_possible = False # Fast path: skip parsing until '<' seen - tool_choice = getattr(request, "tool_choice", None) - if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": - # Initialize parser if needed (same as _parse_tool_calls_with_parser) - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - logger.info(f"Initialized tool call parser: {_tool_call_parser}") - except Exception as e: - logger.warning(f"Failed to init tool parser for streaming: {e}") - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() + tool_markup_possible = False # Fast path: skip parsing until markers appear + tool_parser = _get_streaming_tool_parser(request, engine) try: # Stream content @@ -2689,7 +5008,12 @@ async def stream_chat_completion( # Tool call parsing on content portion if tool_parser and content: - if not tool_markup_possible and "<" not in content: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content + ) + ): tool_accumulated_text += content # Suppress whitespace-only content when tools are active; # avoids emitting stray newlines before tool call XML. 
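# --- Editorial sketch (assumption; implementation not shown in this diff) ---
# The hunks above gate streaming tool-call parsing on a helper named
# _streaming_tool_markup_possible, whose body is not part of this excerpt.
# The sketch below is only a guess at its general shape: a cheap substring
# check for common tool-call prefixes, so ordinary text keeps the fast path
# and never hits the full parser. The prefix list is hypothetical.
_TOOL_MARKUP_PREFIXES_SKETCH = ("<tool_call", "<|tool_call", "<function", "[TOOL_CALL")


def _streaming_tool_markup_possible_sketch(text: str) -> bool:
    """Return True when *text* may contain the start of tool-call markup."""
    return any(marker in text for marker in _TOOL_MARKUP_PREFIXES_SKETCH)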
@@ -2765,6 +5089,14 @@ async def stream_chat_completion( if content: content = _TOOL_MARKUP_PATTERN.sub("", content) + # Strip markdown code fences when response_format is set. + if fence_stripper is not None and not tool_calls_detected: + content = fence_stripper.feed(content) if content else "" + if output.finished: + flush = fence_stripper.finalize() + if flush: + content = content + flush + chunk = ChatCompletionChunk( id=response_id, model=_model_name, @@ -2799,10 +5131,16 @@ async def stream_chat_completion( # Tool call streaming parsing if tool_parser and delta_text: - # Fast path: skip full parsing until '<' is seen in the stream, - # which could start tool markup (e.g. ). This avoids - # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: + # Fast path: skip full parsing until likely tool markup appears. + # This preserves the cheap path for ordinary text while still + # allowing generic streaming tool parsing when no explicit + # parser flags are configured. + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + delta_text + ) + ): tool_accumulated_text += delta_text # No tool markup yet, fall through to normal chunk emission else: @@ -2858,6 +5196,14 @@ async def stream_chat_completion( if content: content = _TOOL_MARKUP_PATTERN.sub("", content) + # Strip markdown code fences when response_format is set. + if fence_stripper is not None and not tool_calls_detected: + content = fence_stripper.feed(content) if content else "" + if output.finished: + flush = fence_stripper.finalize() + if flush: + content = content + flush + chunk = ChatCompletionChunk( id=response_id, model=_model_name, @@ -2883,11 +5229,7 @@ async def stream_chat_completion( tool_parser and tool_accumulated_text and not tool_calls_detected - and ( - "" in tool_accumulated_text - or "<|tool_call>" in tool_accumulated_text - or " 0 else 0 @@ -2975,13 +5350,19 @@ async def init_mcp(config_path: str): global _mcp_manager, _mcp_executor try: - from vllm_mlx.mcp import MCPClientManager, ToolExecutor, load_mcp_config + from vllm_mlx.mcp import ( + MCPClientManager, + ToolExecutor, + ToolSandbox, + load_mcp_config, + ) config = load_mcp_config(config_path) _mcp_manager = MCPClientManager(config) await _mcp_manager.start() - _mcp_executor = ToolExecutor(_mcp_manager) + sandbox = ToolSandbox(allowed_high_risk_tools=config.allowed_high_risk_tools) + _mcp_executor = ToolExecutor(_mcp_manager, sandbox=sandbox) logger.info(f"MCP initialized with {len(_mcp_manager.get_all_tools())} tools") @@ -2989,10 +5370,53 @@ async def init_mcp(config_path: str): logger.error("MCP SDK not installed. Install with: pip install mcp") raise except Exception as e: - logger.error(f"Failed to initialize MCP: {e}") + logger.error("Failed to initialize MCP: %s", _sanitize_log_text(e, limit=500)) raise +# ============================================================================= +# TCP Keepalive +# ============================================================================= + + +def _make_keepalive_http_protocol(idle=10, interval=5, count=3): + """Create a uvicorn HTTP protocol class with aggressive TCP keepalive. + + When a client abruptly disconnects (power-off, network loss), the server + TCP stack won't notice for ~2 hours (default keepalive). 
With aggressive + keepalive (idle=10s, interval=5s, count=3), dead connections are detected + in ~25 seconds, letting ``_wait_with_disconnect()`` abort the request and + stop wasting GPU cycles on tokens nobody will receive. + """ + from uvicorn.protocols.http.h11_impl import H11Protocol + + _Base = H11Protocol + + class _KeepaliveProtocol(_Base): + def connection_made(self, transport): + super().connection_made(transport) + sock = transport.get_extra_info("socket") + if sock is None: + return + try: + sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1) + # macOS: TCP_KEEPALIVE (idle time), Linux: TCP_KEEPIDLE + if hasattr(_socket, "TCP_KEEPALIVE"): + sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, idle) + elif hasattr(_socket, "TCP_KEEPIDLE"): + sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, idle) + if hasattr(_socket, "TCP_KEEPINTVL"): + sock.setsockopt( + _socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, interval + ) + if hasattr(_socket, "TCP_KEEPCNT"): + sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, count) + except OSError: + pass # best-effort; some platforms may not support all options + + return _KeepaliveProtocol + + # ============================================================================= # Main Entry Point # ============================================================================= @@ -3000,6 +5424,106 @@ async def init_mcp(config_path: str): def main(): """Run the server.""" + parser = create_parser() + args = parser.parse_args() + + # Set global configuration + global _api_key, _default_timeout, _rate_limiter, _metrics_enabled + global _default_temperature, _default_top_p, _default_chat_template_kwargs + global _max_audio_upload_bytes, _max_tts_input_chars + _api_key = args.api_key + _default_timeout = args.timeout + _metrics_enabled = args.enable_metrics + _metrics.configure(enabled=args.enable_metrics) + if args.default_temperature is not None: + _default_temperature = args.default_temperature + if args.default_top_p is not None: + _default_top_p = args.default_top_p + _default_chat_template_kwargs = args.default_chat_template_kwargs + _max_audio_upload_bytes = args.max_audio_upload_mb * 1024 * 1024 + _max_tts_input_chars = args.max_tts_input_chars + + # Configure rate limiter + if args.rate_limit > 0: + _rate_limiter = RateLimiter(requests_per_minute=args.rate_limit, enabled=True) + logger.info( + f"Rate limiting enabled: {args.rate_limit} requests/minute per client" + ) + + # Security summary at startup + logger.info("=" * 60) + logger.info("SECURITY CONFIGURATION") + logger.info("=" * 60) + if _api_key: + logger.info(" Authentication: ENABLED (API key required)") + else: + logger.warning(" Authentication: DISABLED - Use --api-key to enable") + if args.rate_limit > 0: + logger.info(f" Rate limiting: ENABLED ({args.rate_limit} req/min)") + else: + logger.warning(" Rate limiting: DISABLED - Use --rate-limit to enable") + logger.info(f" Request timeout: {args.timeout}s") + if args.enable_metrics: + logger.info(" Metrics: ENABLED (/metrics, unauthenticated)") + else: + logger.info(" Metrics: DISABLED - Use --enable-metrics to expose /metrics") + if args.auto_unload_idle_seconds > 0: + logger.info( + " Idle auto-unload: ENABLED (%.0fs)", args.auto_unload_idle_seconds + ) + else: + logger.info(" Idle auto-unload: DISABLED") + if args.trust_remote_code: + logger.warning(" Remote code loading: ENABLED (--trust-remote-code)") + else: + logger.info(" Remote code loading: DISABLED (default)") + logger.info( + f" Audio upload limit: 
{args.max_audio_upload_mb} MiB, " + f"TTS input limit: {args.max_tts_input_chars} chars" + ) + logger.info("=" * 60) + + # Set MCP config for lifespan + if args.mcp_config: + os.environ["VLLM_MLX_MCP_CONFIG"] = args.mcp_config + + # Initialize reasoning parser if specified + if args.reasoning_parser: + global _reasoning_parser + from .reasoning import get_parser + + parser_cls = get_parser(args.reasoning_parser) + _reasoning_parser = parser_cls() + logger.info(f"Reasoning parser enabled: {args.reasoning_parser}") + + # Pre-load embedding model if specified + load_embedding_model(args.embedding_model, lock=True) + + # Load model before starting server + load_model( + args.model, + use_batching=args.continuous_batching, + max_tokens=args.max_tokens, + max_request_tokens=args.max_request_tokens, + force_mllm=args.mllm, + trust_remote_code=args.trust_remote_code, + auto_unload_idle_seconds=args.auto_unload_idle_seconds, + lazy_load_model=args.lazy_load_model, + ) + + # Start server with TCP keepalive for fast dead-client detection. + # Without this, abrupt client disconnects (power-off, network loss) take + # 2+ hours to detect via default TCP keepalive, wasting GPU cycles. + uvicorn.run( + app, + host=args.host, + port=args.port, + http=_make_keepalive_http_protocol(), + ) + + +def create_parser() -> argparse.ArgumentParser: + """Create the standalone server CLI parser.""" parser = argparse.ArgumentParser( description="vllm-mlx OpenAI-compatible server for LLM and MLLM inference", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -3024,8 +5548,8 @@ def main(): parser.add_argument( "--host", type=str, - default="0.0.0.0", - help="Host to bind to", + default="127.0.0.1", + help="Host to bind to (default: localhost; use 0.0.0.0 to expose externally)", ) parser.add_argument( "--port", @@ -3038,6 +5562,11 @@ def main(): action="store_true", help="Force loading as MLLM (multimodal language model)", ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Allow HuggingFace remote code execution during model/tokenizer loading", + ) parser.add_argument( "--continuous-batching", action="store_true", @@ -3055,6 +5584,12 @@ def main(): default=32768, help="Default max tokens for generation", ) + parser.add_argument( + "--max-request-tokens", + type=int, + default=32768, + help="Maximum max_tokens accepted from API clients (default: 32768)", + ) parser.add_argument( "--api-key", type=str, @@ -3072,6 +5607,17 @@ def main(): action="store_true", help="Expose Prometheus metrics on /metrics (disabled by default)", ) + parser.add_argument( + "--auto-unload-idle-seconds", + type=float, + default=0.0, + help="Unload the main model after this many idle seconds (0 = disabled)", + ) + parser.add_argument( + "--lazy-load-model", + action="store_true", + help="Register the main model at startup but defer loading until first request", + ) parser.add_argument( "--rate-limit", type=int, @@ -3110,73 +5656,29 @@ def main(): default=None, help="Default top_p for generation when not specified in request", ) - - args = parser.parse_args() - - # Set global configuration - global _api_key, _default_timeout, _rate_limiter, _metrics_enabled - global _default_temperature, _default_top_p - _api_key = args.api_key - _default_timeout = args.timeout - _metrics_enabled = args.enable_metrics - _metrics.configure(enabled=args.enable_metrics) - if args.default_temperature is not None: - _default_temperature = args.default_temperature - if args.default_top_p is not None: - _default_top_p = args.default_top_p 
- - # Configure rate limiter - if args.rate_limit > 0: - _rate_limiter = RateLimiter(requests_per_minute=args.rate_limit, enabled=True) - logger.info( - f"Rate limiting enabled: {args.rate_limit} requests/minute per client" - ) - - # Security summary at startup - logger.info("=" * 60) - logger.info("SECURITY CONFIGURATION") - logger.info("=" * 60) - if _api_key: - logger.info(" Authentication: ENABLED (API key required)") - else: - logger.warning(" Authentication: DISABLED - Use --api-key to enable") - if args.rate_limit > 0: - logger.info(f" Rate limiting: ENABLED ({args.rate_limit} req/min)") - else: - logger.warning(" Rate limiting: DISABLED - Use --rate-limit to enable") - logger.info(f" Request timeout: {args.timeout}s") - if args.enable_metrics: - logger.info(" Metrics: ENABLED (/metrics, unauthenticated)") - else: - logger.info(" Metrics: DISABLED - Use --enable-metrics to expose /metrics") - logger.info("=" * 60) - - # Set MCP config for lifespan - if args.mcp_config: - os.environ["VLLM_MLX_MCP_CONFIG"] = args.mcp_config - - # Initialize reasoning parser if specified - if args.reasoning_parser: - global _reasoning_parser - from .reasoning import get_parser - - parser_cls = get_parser(args.reasoning_parser) - _reasoning_parser = parser_cls() - logger.info(f"Reasoning parser enabled: {args.reasoning_parser}") - - # Pre-load embedding model if specified - load_embedding_model(args.embedding_model, lock=True) - - # Load model before starting server - load_model( - args.model, - use_batching=args.continuous_batching, - max_tokens=args.max_tokens, - force_mllm=args.mllm, + parser.add_argument( + "--default-chat-template-kwargs", + type=make_json_object_arg_parser("--default-chat-template-kwargs"), + default=None, + help=( + "Default chat template kwargs to apply to all requests when request " + "chat_template_kwargs is omitted or empty; empty request kwargs use " + 'existing server defaults (JSON object, e.g. {"enable_thinking": false})' + ), ) - - # Start server - uvicorn.run(app, host=args.host, port=args.port) + parser.add_argument( + "--max-audio-upload-mb", + type=int, + default=DEFAULT_MAX_AUDIO_UPLOAD_MB, + help="Maximum size of uploaded audio files in MiB (default: 25)", + ) + parser.add_argument( + "--max-tts-input-chars", + type=int, + default=DEFAULT_MAX_TTS_INPUT_CHARS, + help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)", + ) + return parser if __name__ == "__main__": diff --git a/vllm_mlx/specprefill.py b/vllm_mlx/specprefill.py index 9ebe4401a..0be77b8d0 100644 --- a/vllm_mlx/specprefill.py +++ b/vllm_mlx/specprefill.py @@ -155,32 +155,48 @@ def _unpatch_attention_capture(model, originals): _set_attn_module(model.layers[layer_idx], orig) -def _prefill_draft(model, tokens, cache, step_size=2048): +def _prefill_draft(model, tokens, cache, step_size=2048, cancel_check=None): """Prefill prompt tokens into cache. 
Returns logits from last token.""" prompt = mx.array(tokens) if not isinstance(tokens, mx.array) else tokens n = len(tokens) processed = 0 while n - processed > 1: + if cancel_check is not None: + cancel_check() chunk = min(step_size, n - processed - 1) model(prompt[processed : processed + chunk][None], cache=cache) mx.eval([c.state for c in cache]) processed += chunk mx.clear_cache() + if cancel_check is not None: + cancel_check() logits = model(prompt[processed:][None], cache=cache) mx.eval(logits) return logits -def _lookahead_decode(model, first_logits, cache, n_steps, temp=0.6, top_p=0.95): +def _lookahead_decode( + model, + first_logits, + cache, + n_steps, + temp=0.6, + top_p=0.95, + cancel_check=None, +): """Run n_steps autoregressive decode, returning generated token ids. Query vectors are captured by the monkey-patched attention layers. """ sampler = make_sampler(temp=temp, top_p=top_p) + if cancel_check is not None: + cancel_check() y = sampler(first_logits[:, -1, :]) mx.eval(y) generated = [y.item()] for _ in range(n_steps): + if cancel_check is not None: + cancel_check() logits = model(y.reshape(1, -1), cache=cache) y = sampler(logits[:, -1, :]) mx.eval(y) @@ -264,6 +280,7 @@ def score_tokens( top_p=0.95, prefill_step_size=2048, query_extractor=None, + cancel_check=None, ): """Score token importance using attention-based analysis on a draft model. @@ -320,7 +337,13 @@ def score_tokens( # Phase 1: Prefill cache = make_prompt_cache(model) - logits = _prefill_draft(model, tokens, cache, step_size=prefill_step_size) + logits = _prefill_draft( + model, + tokens, + cache, + step_size=prefill_step_size, + cancel_check=cancel_check, + ) # Phase 2: Lookahead decode with query capture query_buffer = [[] for _ in range(n_attn_layers)] @@ -328,7 +351,15 @@ def score_tokens( model, query_buffer, query_extractor=query_extractor ) try: - _lookahead_decode(model, logits, cache, n_lookahead, temp=temp, top_p=top_p) + _lookahead_decode( + model, + logits, + cache, + n_lookahead, + temp=temp, + top_p=top_p, + cancel_check=cancel_check, + ) mx.eval(query_buffer) finally: _unpatch_attention_capture(model, patches) @@ -338,6 +369,8 @@ def score_tokens( # compacted for Nemotron-H where only M/* layers have cache entries) layer_to_cache = _build_layer_to_cache_map(model) attn_caches = [cache[layer_to_cache[i]] for i in attn_indices] + if cancel_check is not None: + cancel_check() importance = _compute_importance( query_buffer, attn_caches, @@ -606,7 +639,13 @@ def _build_layer_to_cache_map(model): def sparse_prefill( - model, tokens, selected_indices, cache, step_size=2048, position_offset=0 + model, + tokens, + selected_indices, + cache, + step_size=2048, + position_offset=0, + cancel_check=None, ): """Prefill the model cache with selected tokens at their original positions. @@ -692,6 +731,8 @@ def sparse_prefill( processed = 0 while n - processed > 1: + if cancel_check is not None: + cancel_check() chunk = min(step_size, n - processed - 1) model(prompt[processed : processed + chunk][None], cache=cache) mx.eval([c.state for c in cache]) @@ -699,6 +740,8 @@ def sparse_prefill( mx.clear_cache() # Last token → logits + if cancel_check is not None: + cancel_check() logits = model(prompt[processed:][None], cache=cache) mx.eval(logits) diff --git a/vllm_mlx/ssd_cache.py b/vllm_mlx/ssd_cache.py new file mode 100644 index 000000000..e61aba33b --- /dev/null +++ b/vllm_mlx/ssd_cache.py @@ -0,0 +1,1003 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +SSD KV cache tiering for vllm-mlx. 
+ +This module provides a cold-tier disk cache that sits behind +MemoryAwarePrefixCache. Evicted entries spill to NVMe instead of being +discarded, and cold-tier fetches reload from disk asynchronously with +RAM budget reservation before the read completes. + +Key design: +- SQLite for atomic metadata index (no mutable JSON) +- Async writer thread for non-blocking spills +- Per-layer serializer interface for hybrid cache types +- Atomic temp-file + rename writes for crash consistency +- Metrics exposed from day one +""" + +from __future__ import annotations + +import array as _array +import hashlib +import json +import logging +import os +import queue +import sqlite3 +import threading +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + +_BYTES_PER_MB = 1024 * 1024 +_BYTES_PER_GB = 1024 * 1024 * 1024 + + +@dataclass(frozen=True) +class SSDCacheConfig: + """Configuration for SSD cache tier. + + Attributes: + cache_dir: Directory for SSD cache files. None = auto-detect + (~/.cache/vllm-mlx/ssd_cache/{model}/). + max_size_gb: Maximum total size of SSD cache in GB. + max_entries: Maximum number of entries in SSD cache. + file_permissions: Unix permission bits for cache data files. + dir_permissions: Unix permission bits for cache directories. + spill_queue_size: Max pending spill operations before dropping. + retention_seconds: Optional max age for cache entries (None = no expiry). + """ + + cache_dir: str | None = None + max_size_gb: float = 10.0 + max_entries: int = 10000 + file_permissions: int = 0o600 + dir_permissions: int = 0o700 + spill_queue_size: int = 64 + retention_seconds: int | None = None + + def __post_init__(self) -> None: + if self.max_size_gb <= 0: + raise ValueError(f"max_size_gb must be > 0, got {self.max_size_gb}") + if self.max_entries < 1: + raise ValueError(f"max_entries must be >= 1, got {self.max_entries}") + if self.spill_queue_size < 1: + raise ValueError( + f"spill_queue_size must be >= 1, got {self.spill_queue_size}" + ) + + @property + def max_size_bytes(self) -> int: + """Maximum cache size in bytes.""" + return int(self.max_size_gb * _BYTES_PER_GB) + + +@dataclass +class SSDCacheStats: + """Statistics for SSD cache tier — exposed from day one. + + Attributes: + spill_count: Number of entries spilled to SSD. + spill_bytes: Total bytes written to SSD. + ssd_hits: Number of successful SSD cache lookups. + ssd_misses: Number of SSD cache lookup misses. + reload_latency_sum: Sum of reload latencies in seconds. + reload_bytes: Total bytes read from SSD. + promotion_failures: Number of failed promotions (RAM budget exhausted). 
+ """ + + spill_count: int = 0 + spill_bytes: int = 0 + ssd_hits: int = 0 + ssd_misses: int = 0 + reload_latency_sum: float = 0.0 + reload_bytes: int = 0 + promotion_failures: int = 0 + + def to_dict(self) -> dict: + total_lookups = self.ssd_hits + self.ssd_misses + hit_rate = self.ssd_hits / total_lookups if total_lookups > 0 else 0.0 + avg_latency_ms = ( + (self.reload_latency_sum / self.ssd_hits * 1000) + if self.ssd_hits > 0 + else 0.0 + ) + return { + "spill_count": self.spill_count, + "spill_bytes": self.spill_bytes, + "ssd_hits": self.ssd_hits, + "ssd_misses": self.ssd_misses, + "ssd_hit_rate": round(hit_rate, 4), + "reload_latency_sum_s": round(self.reload_latency_sum, 4), + "avg_reload_latency_ms": round(avg_latency_ms, 2), + "reload_bytes": self.reload_bytes, + "promotion_failures": self.promotion_failures, + } + + +def _tokens_to_blob(tokens: tuple[int, ...]) -> bytes: + """Serialize token tuple to a compact binary blob for SQLite storage. + + Uses the full token sequence as a binary blob for prefix matching. + """ + arr = _array.array("i", tokens) + return arr.tobytes() + + +def _blob_to_tokens(blob: bytes) -> tuple[int, ...]: + """Deserialize binary blob back to token tuple.""" + arr = _array.array("i") + arr.frombytes(blob) + return tuple(arr) + + +def _tokens_hash(tokens: tuple[int, ...]) -> str: + """Compute SHA-256 hex digest of a token sequence for use as primary key.""" + return hashlib.sha256(_tokens_to_blob(tokens)).hexdigest() + + +class SSDIndex: + """SQLite-backed index for SSD cache entries. + + Uses SQLite for atomic metadata operations instead of mutable JSON. + The token sequence is stored as a binary blob for prefix-searchable + representation. The primary key is a SHA-256 hash of the token sequence. + + Thread safety: All operations are serialized through a threading.Lock. + The SQLite connection uses WAL mode for concurrent read/write safety. 
+ """ + + _SCHEMA_VERSION = 1 + + def __init__(self, cache_dir: str) -> None: + self._cache_dir = cache_dir + self._db_lock = threading.Lock() + db_path = os.path.join(cache_dir, "index.db") + self._conn = sqlite3.connect(db_path, check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.row_factory = sqlite3.Row + self._create_tables() + + def _create_tables(self) -> None: + schema_sql = """ + CREATE TABLE IF NOT EXISTS schema_version ( + version INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS entries ( + token_hash TEXT PRIMARY KEY, + tokens_blob BLOB NOT NULL, + num_tokens INTEGER NOT NULL, + file_path TEXT NOT NULL, + memory_bytes INTEGER NOT NULL, + created_at REAL NOT NULL, + accessed_at REAL NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_entries_accessed + ON entries(accessed_at); + + CREATE INDEX IF NOT EXISTS idx_entries_num_tokens + ON entries(num_tokens); + """ + self._conn.executescript(schema_sql) + # Insert schema version if not present + cur = self._conn.execute("SELECT COUNT(*) FROM schema_version") + if cur.fetchone()[0] == 0: + self._conn.execute( + "INSERT INTO schema_version (version) VALUES (?)", + (self._SCHEMA_VERSION,), + ) + self._conn.commit() + + def insert_entry( + self, + tokens_key: tuple[int, ...], + file_path: str, + memory_bytes: int, + num_tokens: int, + ) -> None: + """Insert or replace a cache entry in the index.""" + now = time.time() + token_hash = _tokens_hash(tokens_key) + tokens_blob = _tokens_to_blob(tokens_key) + with self._db_lock: + self._conn.execute( + """ + INSERT OR REPLACE INTO entries + (token_hash, tokens_blob, num_tokens, file_path, + memory_bytes, created_at, accessed_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + token_hash, + tokens_blob, + num_tokens, + file_path, + memory_bytes, + now, + now, + ), + ) + self._conn.commit() + + def lookup_exact(self, tokens_key: tuple[int, ...]) -> dict | None: + """Look up an exact token sequence. Returns dict or None.""" + token_hash = _tokens_hash(tokens_key) + with self._db_lock: + cur = self._conn.execute( + "SELECT file_path, memory_bytes, num_tokens FROM entries WHERE token_hash = ?", + (token_hash,), + ) + row = cur.fetchone() + if row is None: + return None + return { + "file_path": row["file_path"], + "memory_bytes": row["memory_bytes"], + "num_tokens": row["num_tokens"], + } + + def lookup_prefix(self, query_tokens: tuple[int, ...]) -> list[dict]: + """Find entries whose token sequence is a prefix of query_tokens. + + Scans entries with num_tokens <= len(query_tokens) and compares the + full stored token blob against the corresponding prefix of query_tokens. + + Returns list of dicts sorted by num_tokens descending (longest prefix first). + """ + query_len = len(query_tokens) + query_blob = _tokens_to_blob(query_tokens) + + with self._db_lock: + cur = self._conn.execute( + "SELECT token_hash, tokens_blob, num_tokens, file_path, memory_bytes " + "FROM entries WHERE num_tokens <= ? 
ORDER BY num_tokens DESC", + (query_len,), + ) + rows = cur.fetchall() + + results = [] + for row in rows: + stored_blob = row["tokens_blob"] + n = row["num_tokens"] + prefix_blob = query_blob[: n * 4] + if stored_blob == prefix_blob: + results.append( + { + "token_hash": row["token_hash"], + "file_path": row["file_path"], + "memory_bytes": row["memory_bytes"], + "num_tokens": n, + } + ) + return results + + def delete_entry(self, tokens_key: tuple[int, ...]) -> None: + """Delete an entry by token sequence.""" + token_hash = _tokens_hash(tokens_key) + with self._db_lock: + self._conn.execute( + "DELETE FROM entries WHERE token_hash = ?", (token_hash,) + ) + self._conn.commit() + + def get_lru(self, limit: int = 10) -> list[dict]: + """Get the least recently used entries, ordered oldest first.""" + with self._db_lock: + cur = self._conn.execute( + "SELECT token_hash, tokens_blob, num_tokens, file_path, memory_bytes " + "FROM entries ORDER BY accessed_at ASC LIMIT ?", + (limit,), + ) + rows = cur.fetchall() + results = [] + for row in rows: + results.append( + { + "token_hash": row["token_hash"], + "tokens_blob": row["tokens_blob"], + "file_path": row["file_path"], + "memory_bytes": row["memory_bytes"], + "num_tokens": row["num_tokens"], + } + ) + return results + + def get_total_bytes(self) -> int: + """Get total memory_bytes across all entries.""" + with self._db_lock: + cur = self._conn.execute( + "SELECT COALESCE(SUM(memory_bytes), 0) FROM entries" + ) + return cur.fetchone()[0] + + def get_entry_count(self) -> int: + """Get number of entries in the index.""" + with self._db_lock: + cur = self._conn.execute("SELECT COUNT(*) FROM entries") + return cur.fetchone()[0] + + def touch(self, tokens_key: tuple[int, ...]) -> None: + """Update accessed_at timestamp for an entry (marks as recently used).""" + token_hash = _tokens_hash(tokens_key) + with self._db_lock: + self._conn.execute( + "UPDATE entries SET accessed_at = ? WHERE token_hash = ?", + (time.time(), token_hash), + ) + self._conn.commit() + + def all_entries(self) -> list[dict]: + """Return all entries (for startup reconciliation).""" + with self._db_lock: + cur = self._conn.execute( + "SELECT token_hash, tokens_blob, num_tokens, file_path, memory_bytes " + "FROM entries ORDER BY accessed_at DESC" + ) + rows = cur.fetchall() + results = [] + for row in rows: + results.append( + { + "token_hash": row["token_hash"], + "tokens_blob": row["tokens_blob"], + "file_path": row["file_path"], + "memory_bytes": row["memory_bytes"], + "num_tokens": row["num_tokens"], + } + ) + return results + + def close(self) -> None: + """Close the SQLite connection.""" + with self._db_lock: + self._conn.close() + + +# Support matrix: maps cache type names to their serializer status +SERIALIZER_SUPPORT_MATRIX = { + "KVCache": "supported", + "RotatingKVCache": "supported", # Serialized as KVCache (keys/values/offset) + "ArraysCache": "supported", + "MambaCache": "supported", # Legacy name for ArraysCache + "_QuantizedCacheWrapper": "not_supported_spill_dequantized", +} + + +class LayerSerializer(ABC): + """Interface for per-layer cache serialization. + + Each implementation handles a specific cache type's serialization + to/from safetensors files with metadata. + """ + + @abstractmethod + def serialize_layer( + self, layer: Any, layer_idx: int, file_path: str + ) -> dict[str, Any]: + """Serialize a single cache layer to a file. + + Args: + layer: The cache layer object. + layer_idx: Index of this layer in the cache list. 
+ file_path: Path to write the safetensors file. + + Returns: + Metadata dict with at least 'layer_type' key. + """ + ... + + @abstractmethod + def deserialize_layer(self, file_path: str, metadata: dict[str, Any]) -> dict: + """Deserialize a single cache layer from a file. + + Args: + file_path: Path to the safetensors file. + metadata: Metadata dict from serialize_layer. + + Returns: + Dict with layer state (keys/values/offset or state list). + """ + ... + + +class KVCacheSerializer(LayerSerializer): + """Serializer for KVCache and RotatingKVCache layers. + + Handles layers with .keys, .values, .offset attributes. + RotatingKVCache also has .max_size, .keep, .step, ._idx. + """ + + def serialize_layer( + self, layer: Any, layer_idx: int, file_path: str + ) -> dict[str, Any]: + from safetensors.numpy import save_file + + keys_np = np.array(layer.keys) + values_np = np.array(layer.values) + + tensors = { + f"layer_{layer_idx}_keys": keys_np, + f"layer_{layer_idx}_values": values_np, + } + save_file(tensors, file_path) + + metadata = { + "layer_type": "KVCache", + "layer_idx": layer_idx, + "offset": layer.offset, + } + # Preserve RotatingKVCache extra attributes + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + metadata[attr] = getattr(layer, attr) + + return metadata + + def deserialize_layer(self, file_path: str, metadata: dict[str, Any]) -> dict: + from safetensors.numpy import load_file + + layer_idx = metadata["layer_idx"] + tensors = load_file(file_path) + + result = { + "keys": tensors[f"layer_{layer_idx}_keys"], + "values": tensors[f"layer_{layer_idx}_values"], + "offset": metadata["offset"], + } + for attr in ("max_size", "keep", "step", "_idx"): + if attr in metadata: + result[attr] = metadata[attr] + return result + + +class ArraysCacheSerializer(LayerSerializer): + """Serializer for ArraysCache (Mamba/linear attention) layers. + + Handles layers with .state attribute containing a list of arrays. + """ + + def serialize_layer( + self, layer: Any, layer_idx: int, file_path: str + ) -> dict[str, Any]: + from safetensors.numpy import save_file + + state = layer.state + tensors = {} + for i, arr in enumerate(state): + tensors[f"layer_{layer_idx}_state_{i}"] = np.array(arr) + + save_file(tensors, file_path) + + return { + "layer_type": "ArraysCache", + "layer_idx": layer_idx, + "num_arrays": len(state), + } + + def deserialize_layer(self, file_path: str, metadata: dict[str, Any]) -> dict: + from safetensors.numpy import load_file + + layer_idx = metadata["layer_idx"] + num_arrays = metadata["num_arrays"] + tensors = load_file(file_path) + + state = [] + for i in range(num_arrays): + state.append(tensors[f"layer_{layer_idx}_state_{i}"]) + return {"state": state} + + +def get_serializer_for_layer(layer: Any) -> LayerSerializer: + """Return the appropriate serializer for a cache layer. + + Dispatches based on duck-typing: + - If layer has .keys and .values and .offset -> KVCacheSerializer + - If layer has .state and it's a list -> ArraysCacheSerializer + + Raises ValueError for unsupported layer types. + """ + if hasattr(layer, "keys") and hasattr(layer, "values") and hasattr(layer, "offset"): + return KVCacheSerializer() + if hasattr(layer, "state") and isinstance(getattr(layer, "state", None), list): + return ArraysCacheSerializer() + raise ValueError( + f"Unsupported cache layer type: {type(layer).__name__}. " + f"Supported: {list(SERIALIZER_SUPPORT_MATRIX.keys())}" + ) + + +class SSDCacheTier: + """Cold-tier disk cache for KV cache entries. 
+ + Manages a SQLite-indexed on-disk cache directory. Evicted RAM entries + are spilled here via an async writer thread. Cold-tier fetches reload + from disk asynchronously with RAM budget reservation. + + Directory layout:: + + cache_dir/ + index.db # SQLite metadata index + data/ # safetensors files per entry + {hash}/ # one directory per entry + layer_0.safetensors + layer_1.safetensors + manifest.json # per-entry layer metadata + """ + + def __init__(self, config: SSDCacheConfig) -> None: + self._config = config + + if config.cache_dir is None: + raise ValueError("SSDCacheConfig.cache_dir must be set") + + self._cache_dir = config.cache_dir + self._data_dir = os.path.join(self._cache_dir, "data") + + # Create directory structure + os.makedirs(self._cache_dir, mode=config.dir_permissions, exist_ok=True) + os.makedirs(self._data_dir, mode=config.dir_permissions, exist_ok=True) + + # Open SQLite index + self._index = SSDIndex(self._cache_dir) + + # Stats + self._stats = SSDCacheStats() + self._lock = threading.Lock() + self._closed = False + + # Spill queue and writer thread + self._spill_queue: queue.Queue = queue.Queue(maxsize=config.spill_queue_size) + self._writer_thread: threading.Thread | None = None + self._writer_stop = threading.Event() + + @staticmethod + def _entry_hash(tokens: tuple[int, ...]) -> str: + """Compute deterministic hash for a token sequence.""" + return _tokens_hash(tokens) + + def get_stats(self) -> dict: + """Return current SSD cache statistics.""" + return self._stats.to_dict() + + def start_writer(self) -> None: + """Start the background spill writer thread.""" + if self._writer_thread is not None: + return + self._writer_stop.clear() + self._writer_thread = threading.Thread( + target=self._writer_loop, daemon=True, name="ssd-cache-writer" + ) + self._writer_thread.start() + logger.info("[ssd_cache] writer thread started") + + def _writer_loop(self) -> None: + """Background loop: drain spill queue and write to disk.""" + while not self._writer_stop.is_set(): + try: + item = self._spill_queue.get(timeout=0.5) + except queue.Empty: + continue + + if item is None: # Poison pill for shutdown + break + + tokens_key, cache_layers, memory_bytes = item + try: + self._write_entry(tokens_key, cache_layers, memory_bytes) + except Exception: + logger.exception( + f"[ssd_cache] failed to write entry " f"({len(tokens_key)} tokens)" + ) + + def enqueue_spill( + self, + tokens: tuple[int, ...], + cache: list[Any], + memory_bytes: int, + ) -> bool: + """Enqueue a cache entry for async spill to SSD. + + Returns True if enqueued, False if queue is full (entry dropped). + """ + try: + self._spill_queue.put_nowait((tokens, cache, memory_bytes)) + return True + except queue.Full: + logger.warning( + f"[ssd_cache] spill queue full, dropping entry " + f"({len(tokens)} tokens, {memory_bytes} bytes)" + ) + return False + + def _write_entry( + self, + tokens_key: tuple[int, ...], + cache_layers: list[Any], + memory_bytes: int, + ) -> None: + """Write a single cache entry to disk atomically. + + Uses temp-file + rename for crash consistency. 
+ """ + import shutil + + entry_hash = self._entry_hash(tokens_key) + entry_dir = os.path.join(self._data_dir, entry_hash) + tmp_dir = entry_dir + ".tmp" + + # Clean up any leftover tmp dir from a previous crash + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir) + + os.makedirs(tmp_dir, mode=self._config.dir_permissions, exist_ok=True) + + layer_manifests = [] + total_file_bytes = 0 + + for i, layer in enumerate(cache_layers): + serializer = get_serializer_for_layer(layer) + layer_path = os.path.join(tmp_dir, f"layer_{i}.safetensors") + metadata = serializer.serialize_layer(layer, i, layer_path) + layer_manifests.append(metadata) + + # Set file permissions + os.chmod(layer_path, self._config.file_permissions) + total_file_bytes += os.path.getsize(layer_path) + + # Write manifest + manifest = { + "num_layers": len(cache_layers), + "layers": layer_manifests, + "memory_bytes": memory_bytes, + "num_tokens": len(tokens_key), + } + manifest_path = os.path.join(tmp_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump(manifest, f) + os.chmod(manifest_path, self._config.file_permissions) + + # Save tokens binary + tokens_path = os.path.join(tmp_dir, "tokens.bin") + arr = _array.array("i", tokens_key) + with open(tokens_path, "wb") as f: + arr.tofile(f) + os.chmod(tokens_path, self._config.file_permissions) + + # Atomic rename: tmp_dir -> entry_dir + if os.path.exists(entry_dir): + shutil.rmtree(entry_dir) + os.rename(tmp_dir, entry_dir) + + # Update index + relative_path = entry_hash + self._index.insert_entry( + tokens_key=tokens_key, + file_path=relative_path, + memory_bytes=memory_bytes, + num_tokens=len(tokens_key), + ) + + # Update stats + with self._lock: + self._stats.spill_count += 1 + self._stats.spill_bytes += total_file_bytes + + logger.debug( + f"[ssd_cache] spilled entry: {len(tokens_key)} tokens, " + f"{total_file_bytes} bytes on disk" + ) + + # Enforce capacity after write + self._enforce_capacity() + + def lookup_ssd(self, tokens: tuple[int, ...]) -> dict | None: + """Synchronous check whether tokens exist in SSD tier. + + This is fast (SQLite lookup only, no disk I/O for data). + Called from synchronous fetch() to report an SSD candidate. + + Returns: + Dict with entry metadata if found, None otherwise. + """ + result = self._index.lookup_exact(tokens) + if result is not None: + return result + return None + + def lookup_ssd_prefix(self, tokens: tuple[int, ...]) -> dict | None: + """Find the longest prefix match in the SSD tier. + + Returns the longest-prefix entry metadata or None. + """ + results = self._index.lookup_prefix(tokens) + if results: + return results[0] # Already sorted by num_tokens DESC + return None + + async def async_promote( + self, + tokens: tuple[int, ...], + reserve_budget_fn, + release_budget_fn, + ) -> list | None: + """Promote an entry from SSD to RAM asynchronously. + + CRITICAL: Reserves RAM budget BEFORE the disk read, to avoid + thrash when multiple promotions race. + + Args: + tokens: Token sequence to promote. + reserve_budget_fn: Callable(nbytes) -> bool. Must return True + if budget is available and reserved, False otherwise. + release_budget_fn: Callable(nbytes) -> None. Called to release + budget on failure. + + Returns: + List of deserialized cache layers, or None if promotion failed. 
+ """ + import asyncio + + # Step 1: Look up metadata (fast, SQLite) + meta = self._index.lookup_exact(tokens) + if meta is None: + with self._lock: + self._stats.ssd_misses += 1 + return None + + memory_bytes = meta["memory_bytes"] + + # Step 2: Reserve RAM budget BEFORE disk read + if not reserve_budget_fn(memory_bytes): + with self._lock: + self._stats.promotion_failures += 1 + logger.warning( + f"[ssd_cache] promotion denied: cannot reserve " + f"{memory_bytes} bytes RAM budget" + ) + return None + + # Step 3: Read from disk (in thread pool to avoid blocking event loop) + # Use shield-and-await-on-cancel per CLAUDE.md Golden Rule #4: + # budget must be released even if the calling task is cancelled. + t0 = time.time() + worker = asyncio.ensure_future( + asyncio.to_thread(self._read_entry, tokens, meta["file_path"]) + ) + try: + cache_layers = await asyncio.shield(worker) + except asyncio.CancelledError: + # Caller cancelled — still need to wait for the disk read + # to finish, then release the budget + try: + await worker + except Exception: + pass + release_budget_fn(memory_bytes) + raise + except Exception: + # Release budget on read failure + release_budget_fn(memory_bytes) + with self._lock: + self._stats.promotion_failures += 1 + logger.exception( + f"[ssd_cache] failed to read entry from disk " + f"({meta['num_tokens']} tokens)" + ) + return None + + if cache_layers is None: + # Corrupted entry — release budget, quarantine entry + release_budget_fn(memory_bytes) + with self._lock: + self._stats.promotion_failures += 1 + return None + + dt = time.time() - t0 + total_read_bytes = sum( + os.path.getsize( + os.path.join( + self._data_dir, meta["file_path"], f"layer_{i}.safetensors" + ) + ) + for i in range(len(cache_layers)) + if os.path.exists( + os.path.join( + self._data_dir, meta["file_path"], f"layer_{i}.safetensors" + ) + ) + ) + + with self._lock: + self._stats.ssd_hits += 1 + self._stats.reload_latency_sum += dt + self._stats.reload_bytes += total_read_bytes + + # Update access time in index + self._index.touch(tokens) + + logger.info( + f"[ssd_cache] promoted entry: {meta['num_tokens']} tokens, " + f"{total_read_bytes} bytes, {dt*1000:.1f}ms" + ) + + return cache_layers + + def _read_entry(self, tokens: tuple[int, ...], relative_path: str) -> list | None: + """Read a cache entry from disk. Called from thread pool. + + Returns list of deserialized layer dicts, or None on corruption. 
+ """ + entry_dir = os.path.join(self._data_dir, relative_path) + manifest_path = os.path.join(entry_dir, "manifest.json") + + try: + with open(manifest_path) as f: + manifest = json.load(f) + except (json.JSONDecodeError, OSError) as e: + logger.warning(f"[ssd_cache] corrupt manifest for {relative_path}: {e}") + self._quarantine_entry(tokens, relative_path) + return None + + cache_layers = [] + for layer_meta in manifest["layers"]: + layer_idx = layer_meta["layer_idx"] + layer_path = os.path.join(entry_dir, f"layer_{layer_idx}.safetensors") + layer_type = layer_meta["layer_type"] + + try: + if layer_type in ("KVCache", "RotatingKVCache"): + serializer = KVCacheSerializer() + elif layer_type in ("ArraysCache", "MambaCache"): + serializer = ArraysCacheSerializer() + else: + logger.warning( + f"[ssd_cache] unknown layer type {layer_type}, skipping" + ) + self._quarantine_entry(tokens, relative_path) + return None + + layer_data = serializer.deserialize_layer(layer_path, layer_meta) + cache_layers.append(layer_data) + except Exception as e: + logger.warning( + f"[ssd_cache] corrupt layer {layer_idx} in {relative_path}: {e}" + ) + self._quarantine_entry(tokens, relative_path) + return None + + return cache_layers + + def _quarantine_entry(self, tokens: tuple[int, ...], relative_path: str) -> None: + """Move a corrupt entry to quarantine and remove from index.""" + entry_dir = os.path.join(self._data_dir, relative_path) + quarantine_dir = os.path.join(self._cache_dir, "quarantine", relative_path) + + try: + if os.path.exists(entry_dir): + os.makedirs( + os.path.dirname(quarantine_dir), + mode=self._config.dir_permissions, + exist_ok=True, + ) + os.rename(entry_dir, quarantine_dir) + logger.warning( + f"[ssd_cache] quarantined corrupt entry: {relative_path}" + ) + except OSError as e: + logger.warning(f"[ssd_cache] failed to quarantine {relative_path}: {e}") + + self._index.delete_entry(tokens) + + def _enforce_capacity(self) -> None: + """Evict oldest SSD entries until within capacity limits. + + Called after each spill write. Removes entries by LRU order + until both entry count and total bytes are within bounds. + """ + import shutil + + while True: + entry_count = self._index.get_entry_count() + total_bytes = self._index.get_total_bytes() + + needs_evict = ( + entry_count > self._config.max_entries + or total_bytes > self._config.max_size_bytes + ) + if not needs_evict: + break + + lru = self._index.get_lru(limit=1) + if not lru: + break + + victim = lru[0] + victim_tokens = _blob_to_tokens(victim["tokens_blob"]) + victim_dir = os.path.join(self._data_dir, victim["file_path"]) + + # Delete data files + if os.path.exists(victim_dir): + shutil.rmtree(victim_dir) + + # Delete from index + self._index.delete_entry(victim_tokens) + + logger.debug( + f"[ssd_cache] disk LRU evicted: {victim['num_tokens']} tokens, " + f"{victim['memory_bytes']} bytes" + ) + + def reconcile(self) -> int: + """Reconcile index with files on disk. + + Removes index entries whose data files are missing. + Removes data directories not in the index. + + Returns number of entries cleaned up. 
+ """ + import shutil + + cleaned = 0 + + # Phase 1: Remove index entries with missing data dirs + all_entries = self._index.all_entries() + for entry in all_entries: + entry_dir = os.path.join(self._data_dir, entry["file_path"]) + manifest_path = os.path.join(entry_dir, "manifest.json") + if not os.path.isdir(entry_dir) or not os.path.exists(manifest_path): + tokens = _blob_to_tokens(entry["tokens_blob"]) + self._index.delete_entry(tokens) + cleaned += 1 + logger.info( + f"[ssd_cache] reconcile: removed orphaned index entry " + f"({entry['num_tokens']} tokens, path={entry['file_path']})" + ) + + # Phase 2: Remove data directories not in the index + if os.path.isdir(self._data_dir): + indexed_hashes = {e["file_path"] for e in self._index.all_entries()} + for entry_name in os.listdir(self._data_dir): + entry_path = os.path.join(self._data_dir, entry_name) + if ( + os.path.isdir(entry_path) + and entry_name not in indexed_hashes + and not entry_name.endswith(".tmp") + ): + shutil.rmtree(entry_path) + cleaned += 1 + logger.info( + f"[ssd_cache] reconcile: removed orphaned data dir " + f"{entry_name}" + ) + + if cleaned > 0: + logger.info(f"[ssd_cache] reconciliation cleaned {cleaned} entries") + + return cleaned + + def close(self) -> None: + """Close the SSD cache tier and release resources.""" + if self._closed: + return + self._closed = True + + # Stop writer thread + self._writer_stop.set() + if self._writer_thread is not None: + try: + self._spill_queue.put_nowait(None) # Poison pill + except queue.Full: + pass + self._writer_thread.join(timeout=5.0) + self._writer_thread = None + + self._index.close() + logger.info("[ssd_cache] SSDCacheTier closed") diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py index cd76ad418..32f0503c3 100644 --- a/vllm_mlx/tool_parsers/__init__.py +++ b/vllm_mlx/tool_parsers/__init__.py @@ -62,11 +62,37 @@ from .harmony_tool_parser import HarmonyToolParser from .minimax_tool_parser import MiniMaxToolParser + +def get_parser_stop_tokens( + parser_name: str | None, + user_stops: list[str] | None, +) -> list[str]: + """Merge user-supplied stops with parser-declared extras (deduped). + + Some models declare end-of-generation tokens beyond the tokenizer's default + eos set — e.g. Gemma 4's ``<|tool_response>`` which signals the runtime's + turn after a tool call. Parsers expose those via ``extra_stop_tokens``. + """ + stops = list(user_stops or []) + if not parser_name: + return stops + try: + parser_cls = ToolParserManager.get_tool_parser(parser_name) + except (KeyError, ImportError): + return stops + for s in getattr(parser_cls, "extra_stop_tokens", []): + if s not in stops: + stops.append(s) + return stops + + __all__ = [ # Base classes "ToolParser", "ToolParserManager", "ExtractedToolCallInformation", + # Helpers + "get_parser_stop_tokens", # Specific parsers "AutoToolParser", "Gemma4ToolParser", diff --git a/vllm_mlx/tool_parsers/abstract_tool_parser.py b/vllm_mlx/tool_parsers/abstract_tool_parser.py index a76f487e7..47ced2303 100644 --- a/vllm_mlx/tool_parsers/abstract_tool_parser.py +++ b/vllm_mlx/tool_parsers/abstract_tool_parser.py @@ -50,6 +50,12 @@ class ToolParser(ABC): # without needing conversion to text format. SUPPORTS_NATIVE_TOOL_FORMAT: bool = False + # Extra stop tokens specific to this parser's model format. The server + # merges these into the request's `stop` list when the parser is active, + # so the model halts on format-specific EOG markers (e.g. 
Gemma 4's + # <|tool_response>) that are not part of the tokenizer's default eos set. + extra_stop_tokens: list[str] = [] + @classmethod def supports_native_format(cls) -> bool: """ diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py index 37ab10d74..c759e33d7 100644 --- a/vllm_mlx/tool_parsers/auto_tool_parser.py +++ b/vllm_mlx/tool_parsers/auto_tool_parser.py @@ -55,6 +55,8 @@ class AutoToolParser(ToolParser): NEMOTRON_PARAM_PATTERN = re.compile( r"]+)>\s*(.*?)\s*", re.DOTALL ) + BARE_BRACKET_PATTERN = re.compile(r"\[(\w+)\((\{.*?\})\)\]", re.DOTALL) + BARE_BRACKET_PARTIAL_PATTERN = re.compile(r"\[\w+\($") def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None @@ -150,7 +152,35 @@ def extract_tool_calls( if bracket_matches: cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip() - # 4. Try Nemotron pattern (before Qwen XML as it's more specific) + # 4. Try bare bracket format: [func({...})] + bare_matches = self.BARE_BRACKET_PATTERN.findall(cleaned_text) + for name, args_str in bare_matches: + try: + arguments = json.loads(args_str) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": ( + json.dumps(arguments, ensure_ascii=False) + if isinstance(arguments, dict) + else str(arguments) + ), + } + ) + except json.JSONDecodeError: + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": args_str, + } + ) + + if bare_matches: + cleaned_text = self.BARE_BRACKET_PATTERN.sub("", cleaned_text).strip() + + # 5. Try Nemotron pattern (before Qwen XML as it's more specific) nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text) for name, params_block in nemotron_matches: params = self.NEMOTRON_PARAM_PATTERN.findall(params_block) @@ -166,7 +196,7 @@ def extract_tool_calls( if nemotron_matches: cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip() - # 5. Try Qwen/Hermes XML pattern + # 6. Try Qwen/Hermes XML pattern xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text) for match in xml_matches: try: @@ -191,7 +221,7 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip() - # 6. Try Llama pattern + # 7. Try Llama pattern llama_matches = self.LLAMA_PATTERN.findall(cleaned_text) for name, args_str in llama_matches: try: @@ -219,7 +249,7 @@ def extract_tool_calls( if llama_matches: cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip() - # 7. Fallback: Try raw JSON + # 8. Fallback: Try raw JSON if not tool_calls: raw_calls = self._parse_raw_json_tool_calls(cleaned_text) if raw_calls: @@ -339,11 +369,24 @@ def extract_tool_calls_streaming( "<|tool_call>", self.MISTRAL_TOKEN, "[Calling tool:", + "[", "", " around. +# This fires when a schema uses a nullable type like ["string","null"] or an +# enum field without explicit "type": the template takes the non-STRING branch +# and emits the value raw. Preceded by a value-position separator (: [ ,), a +# word starting with a letter, followed by ,/}/]. JSON literals true/false/null +# are filtered out inside the substitution. Ref: llama.cpp PR #21327. 
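+# Illustrative example (values made up): a nullable schema may render args as
+# {unit: celsius, retries: null}; after the bare-key pass (step 2) and this
+# bare-value pass (step 3) it becomes {"unit": "celsius", "retries": null},
+# while the JSON literal null is kept as-is.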
+_BARE_VALUE = re.compile(r"(?<=[:\[,])(\s*)([A-Za-z_][\w\-]*)(?=\s*[,}\]])") +_JSON_LITERALS = frozenset({"true", "false", "null"}) + # Max arg block length to prevent runaway parsing on malformed input (1 MB) _MAX_ARG_BLOCK_LEN = 1_048_576 @@ -85,15 +94,27 @@ def _find_balanced_brace(text: str, start: int) -> int: return -1 +def _quote_bare_value(m: re.Match) -> str: + """Substitution callback for _BARE_VALUE — quotes bare identifiers that + are not JSON literals (true/false/null).""" + ws, word = m.group(1), m.group(2) + if word in _JSON_LITERALS: + return m.group(0) + return f'{ws}"{word}"' + + def _gemma4_args_to_json(text: str) -> str: """Convert Gemma 4 tool call args to valid JSON. - Three-step conversion (ORDER MATTERS): + Four-step conversion (ORDER MATTERS): 1. Extract <|"|>-delimited strings into numbered \\x00N\\x00 placeholders. This protects string contents from step 2's bare-key quoting -- without this, a string value like "key: value" would be corrupted. 2. Quote bare keys (word: -> "word":) now that strings are safe. - 3. Restore placeholders as properly JSON-escaped strings via json.dumps(). + 3. Quote bare string VALUES that the template emitted without <|"|> + wrappers. Happens with nullable/enum schemas where the STRING branch + of the template isn't taken. + 4. Restore placeholders as properly JSON-escaped strings via json.dumps(). Uses a single re.sub pass (O(len(text))) instead of per-placeholder replace. """ strings: list[str] = [] @@ -108,7 +129,10 @@ def _capture(m: re.Match) -> str: # Step 2: Quote bare keys text = _BARE_KEY.sub(r'"\1":', text) - # Step 3: Restore captured strings as properly escaped JSON strings + # Step 3: Quote bare string values (nullable / enum-without-type schemas) + text = _BARE_VALUE.sub(_quote_bare_value, text) + + # Step 4: Restore captured strings as properly escaped JSON strings def _restore(m: re.Match) -> str: idx = int(m.group(1)) return json.dumps(strings[idx]) if idx < len(strings) else m.group(0) @@ -133,6 +157,13 @@ class Gemma4ToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser gemma4 are set. """ + # The chat template renders <|tool_response> (token 50) when the assistant + # emits a tool call without its own tool_responses block — it's the signal + # that it's the runtime's turn, not the model's. Treat it as EOG so the + # model doesn't keep generating past the tool call. + # Ref: llama.cpp PR #21418. + extra_stop_tokens = ["<|tool_response>"] + def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None ) -> ExtractedToolCallInformation: diff --git a/vllm_mlx/tool_parsers/kimi_tool_parser.py b/vllm_mlx/tool_parsers/kimi_tool_parser.py index c6c0de4a5..d494ce50b 100644 --- a/vllm_mlx/tool_parsers/kimi_tool_parser.py +++ b/vllm_mlx/tool_parsers/kimi_tool_parser.py @@ -7,7 +7,6 @@ - <|tool_call_begin|>func_name:0<|tool_call_argument_begin|>{...}<|tool_call_end|> """ -import json import re import uuid from collections.abc import Sequence @@ -89,28 +88,17 @@ def extract_tool_calls( matches = self.TOOL_CALL_PATTERN.findall(model_output) for match in matches: func_id, func_args = match - # func_id format: functions.get_weather:0 or get_weather:0 - func_name = func_id.split(":")[-2] if ":" in func_id else func_id + # func_id format: functions.get_weather:0 or get_weather:0 or get_weather + func_name = func_id.rsplit(":", 1)[0] if ":" in func_id else func_id func_name = func_name.split(".")[-1] # Remove 'functions.' 
prefix
-            try:
-                # Validate JSON
-                json.loads(func_args)
-                tool_calls.append(
-                    {
-                        "id": generate_tool_id(),
-                        "name": func_name.strip(),
-                        "arguments": func_args.strip(),
-                    }
-                )
-            except json.JSONDecodeError:
-                tool_calls.append(
-                    {
-                        "id": generate_tool_id(),
-                        "name": func_name.strip(),
-                        "arguments": func_args.strip(),
-                    }
-                )
+            tool_calls.append(
+                {
+                    "id": generate_tool_id(),
+                    "name": func_name.strip(),
+                    "arguments": func_args.strip(),
+                }
+            )
 
         if tool_calls:
             return ExtractedToolCallInformation(
diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py
index e235a3c7d..64c1ed5ed 100644
--- a/vllm_mlx/tool_parsers/qwen_tool_parser.py
+++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py
@@ -58,6 +58,8 @@ class QwenToolParser(ToolParser):
     Used when --enable-auto-tool-choice --tool-call-parser qwen are set.
     """
 
+    SUPPORTS_NATIVE_TOOL_FORMAT = True
+
     # Pattern for XML-style: <tool_call>{"json"}</tool_call>
     XML_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
 
@@ -70,6 +72,9 @@ class QwenToolParser(ToolParser):
     # Pattern for parameter extraction: <parameter=name>value</parameter>
     PARAM_PATTERN = re.compile(r"<parameter=([^>]+)>\s*(.*?)\s*</parameter>", re.DOTALL)
 
+    # Pattern for empty <tool_call> wrappers left after function extraction
+    EMPTY_TOOL_CALL = re.compile(r"<tool_call>\s*</tool_call>", re.DOTALL)
+
     def extract_tool_calls(
         self, model_output: str, request: dict[str, Any] | None = None
     ) -> ExtractedToolCallInformation:
@@ -164,6 +169,8 @@ def extract_tool_calls(
         cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip()
 
         if tool_calls:
+            # Clean up empty <tool_call> wrappers left after function extraction
+            cleaned_text = self.EMPTY_TOOL_CALL.sub("", cleaned_text).strip()
             return ExtractedToolCallInformation(
                 tools_called=True,
                 tool_calls=tool_calls,
@@ -272,9 +279,11 @@ def extract_tool_calls_streaming(
 
             return None
 
-        # If we're in a tool call, accumulate and parse at the end
-        # For simplicity, return None during accumulation
-        if "</tool_call>" in delta_text or ")]" in delta_text:
+        # If we're in a tool call, accumulate and parse at the end.
+        # Check current_text (accumulated), not delta_text — closing markers
+        # like ")]" or "</tool_call>" often span token boundaries and may
+        # never appear within a single delta chunk.
+        if "</tool_call>" in current_text or ")]" in current_text:
            # Tool call complete, parse the whole thing
            result = self.extract_tool_calls(current_text)
            if result.tools_called:
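For reviewers who want to poke at the new SSD tier outside the server, the sketch below spills a toy layer to disk and promotes it back. It assumes the `ssd_cache` module lands as shown above; `ToyKV` and the lambda budget hooks are illustrative stand-ins (not part of the module), and the shapes are arbitrary.

```python
# Minimal end-to-end sketch of SSDCacheTier (assumes numpy + safetensors installed).
# ToyKV mimics the duck-typed KVCache interface the serializer expects; the lambda
# budget hooks stand in for MemoryAwarePrefixCache's reservation callbacks.
import asyncio
import tempfile

import numpy as np

from vllm_mlx.ssd_cache import SSDCacheConfig, SSDCacheTier


class ToyKV:
    """Illustrative stand-in with .keys / .values / .offset."""

    def __init__(self, n_tokens: int) -> None:
        self.keys = np.zeros((1, 2, n_tokens, 8), dtype=np.float32)
        self.values = np.zeros((1, 2, n_tokens, 8), dtype=np.float32)
        self.offset = n_tokens


async def main() -> None:
    tier = SSDCacheTier(SSDCacheConfig(cache_dir=tempfile.mkdtemp(), max_size_gb=1.0))
    tier.start_writer()

    tokens = tuple(range(16))
    tier.enqueue_spill(tokens, [ToyKV(16), ToyKV(16)], memory_bytes=4096)

    # Spill happens on the writer thread; poll the index until the entry lands.
    while tier.lookup_ssd(tokens) is None:
        await asyncio.sleep(0.05)

    # Promotion reserves RAM budget before the disk read; here budget is always granted.
    layers = await tier.async_promote(
        tokens,
        reserve_budget_fn=lambda nbytes: True,
        release_budget_fn=lambda nbytes: None,
    )
    print(len(layers), tier.get_stats())
    tier.close()


asyncio.run(main())
```

After this run, `get_stats()` should report one spill and one SSD hit, which is the quickest way to confirm the writer thread and the promotion path are wired up.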