Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm_mlx/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class ChatCompletionRequest(BaseModel):
# MLLM-specific parameters
video_fps: float | None = None
video_max_frames: int | None = None
extract_audio_from_video: bool | None = None
# Sampling penalties
repetition_penalty: float | None = None # mlx-lm style (>1.0 penalizes)
# Request timeout in seconds (None = use server default)
Expand Down
9 changes: 9 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def serve_command(args):
if args.default_top_p is not None:
server._default_top_p = args.default_top_p

# Configure audio extraction from video
server._extract_audio_from_video = args.extract_audio_from_video

# Configure reasoning parser
if args.reasoning_parser:
try:
Expand Down Expand Up @@ -943,6 +946,12 @@ def create_parser() -> argparse.ArgumentParser:
action="store_true",
help="Force load model as multimodal (vision) even if name doesn't match auto-detection patterns",
)
serve_parser.add_argument(
"--extract-audio-from-video",
action="store_true",
help="Extract audio track from video inputs and pass to the model. "
"Requires ffmpeg. Only useful for models with audio support (e.g. Gemma 4).",
)
# Generation defaults
serve_parser.add_argument(
"--default-temperature",
Expand Down
72 changes: 64 additions & 8 deletions vllm_mlx/models/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,39 @@ def _prepare_video(
)
return save_frames_to_temp(frames)

@staticmethod
def _extract_audio_from_video(video_path: str) -> str | None:
    """Extract the audio track from a video file as a temporary WAV file.

    Transcodes the audio stream with ffmpeg to 16 kHz mono WAV (the
    ``-vn -ar 16000 -ac 1`` flags drop video and resample).

    Args:
        video_path: Path to a local video file readable by ffmpeg.

    Returns:
        Path to the extracted temporary WAV file, or None if the video has
        no audio track, ffmpeg is unavailable, or extraction fails or times
        out (30 s cap). The caller owns the returned file and is
        responsible for deleting it.
    """
    import os
    import subprocess
    import tempfile

    wav_path: str | None = None
    try:
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        wav_path = tmp.name
        result = subprocess.run(
            [
                "ffmpeg", "-y", "-i", video_path,
                "-vn", "-ar", "16000", "-ac", "1", "-f", "wav",
                wav_path,
            ],
            capture_output=True,
            timeout=30,
        )
        # A usable WAV must exceed its 44-byte header; a header-only file
        # means the container had no audio stream.
        if result.returncode == 0 and os.path.getsize(wav_path) > 44:
            return wav_path
    except Exception as e:
        # Best-effort: missing ffmpeg binary, timeout, or filesystem errors
        # all degrade to "no audio extracted" rather than failing the request.
        logger.debug(f"Could not extract audio from video: {e}")
    # Clean up the temp file on EVERY non-success path (including when
    # subprocess.run raised) so delete=False files do not leak.
    if wav_path is not None:
        try:
            os.unlink(wav_path)
        except OSError:
            pass
    return None

def _collect_video_inputs(self, messages: list[dict]) -> dict[int, list]:
"""Collect video inputs from messages, keyed by message index.

Expand Down Expand Up @@ -1319,16 +1352,18 @@ def chat(
from mlx_vlm import generate
from mlx_vlm.prompt_utils import get_chat_template

# Extract text and images from messages
# Build chat_messages for multi-turn support WITH proper image tokens per message
# Extract text, images and audio from messages
# Build chat_messages for multi-turn support WITH proper image/audio tokens per message
all_image_urls = [] # Raw URLs/paths to process later
all_audio_urls = [] # Raw audio URLs/paths to process later
chat_messages = [] # List of properly formatted messages for chat template

logger.info(f"MLLM.chat() called with {len(messages)} messages")

# Pop params early so they don't leak into mlx_vlm.generate()
video_fps = kwargs.pop("video_fps", DEFAULT_FPS)
video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
extract_audio_from_video = kwargs.pop("extract_audio_from_video", False)
tools = kwargs.pop("tools", None)
use_cache = kwargs.pop("use_cache", True)
enable_thinking = kwargs.pop("enable_thinking", True)
Expand Down Expand Up @@ -1360,6 +1395,16 @@ def chat(
all_video_frames.extend(frames)
total_frames += len(frames)
logger.info(f"Added {len(frames)} frames from video: {vid_input}")

# Extract audio track from video if explicitly requested
if extract_audio_from_video:
video_path = vid_input if isinstance(vid_input, str) else vid_input.get("url", vid_input.get("video_url", {}).get("url", ""))
if video_path and not video_path.startswith("data:"):
audio_path = self._extract_audio_from_video(video_path)
if audio_path:
all_audio_urls.append(audio_path)
logger.info(f"Extracted audio from video: {video_path}")

_msg_video_frame_counts[msg_idx] = total_frames

# Second pass: build chat messages with image counts that include video frames
Expand Down Expand Up @@ -1405,17 +1450,27 @@ def chat(
)
msg_image_count += 1

elif item_type == "audio_url":
aud_url = item.get("audio_url", {})
if isinstance(aud_url, str):
all_audio_urls.append(aud_url)
else:
all_audio_urls.append(aud_url.get("url", ""))

# Add video frame count to image count for this message
msg_image_count += _msg_video_frame_counts.get(msg_idx, 0)
msg_audio_count = len(all_audio_urls)

# Build properly structured message for Qwen3-VL-MoE
# Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]}
if msg_text or msg_image_count > 0:
if role == "user" and msg_image_count > 0:
# User message WITH images - build content array with image tokens FIRST
# Build properly structured message
# Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "audio"}, ..., {"type": "text", "text": "..."}]}
if msg_text or msg_image_count > 0 or msg_audio_count > 0:
if role == "user" and (msg_image_count > 0 or msg_audio_count > 0):
# User message WITH images/audio - build content array with media tokens FIRST
content_list = []
for _ in range(msg_image_count):
content_list.append({"type": "image"})
for _ in range(msg_audio_count):
content_list.append({"type": "audio"})
content_list.append(
{"type": "text", "text": msg_text, "content": msg_text}
)
Expand Down Expand Up @@ -1443,7 +1498,7 @@ def chat(

# Apply chat template directly - messages are already properly structured
logger.info(
f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images"
f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images, {len(all_audio_urls)} audios"
)
for i, cm in enumerate(chat_messages):
content_preview = str(cm.get("content", ""))[:80]
Expand Down Expand Up @@ -1586,6 +1641,7 @@ def chat(
self.processor,
formatted_prompt,
all_images if all_images else None,
audio=all_audio_urls if all_audio_urls else None,
max_tokens=max_tokens,
temp=temperature,
verbose=False,
Expand Down
7 changes: 7 additions & 0 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
_default_timeout: float = 300.0 # Default request timeout in seconds (5 minutes)
_default_temperature: float | None = None # Set via --default-temperature
_default_top_p: float | None = None # Set via --default-top-p
_extract_audio_from_video: bool = False # Set via --extract-audio-from-video
_metrics_enabled = False

_FALLBACK_TEMPERATURE = 0.7
Expand Down Expand Up @@ -1663,6 +1664,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
chat_kwargs["video_fps"] = request.video_fps
if request.video_max_frames:
chat_kwargs["video_max_frames"] = request.video_max_frames
extract_audio = (
request.extract_audio_from_video
if request.extract_audio_from_video is not None
else _extract_audio_from_video
)
chat_kwargs["extract_audio_from_video"] = extract_audio

# SpecPrefill: per-request overrides
if request.specprefill is not None:
Expand Down
Loading