diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index 0ebe616eb..f020ec7e0 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -176,6 +176,7 @@ class ChatCompletionRequest(BaseModel):
     # MLLM-specific parameters
     video_fps: float | None = None
     video_max_frames: int | None = None
+    extract_audio_from_video: bool | None = None
     # Sampling penalties
     repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
     # Request timeout in seconds (None = use server default)
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index 36edd91d8..629262a41 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -68,6 +68,9 @@ def serve_command(args):
     if args.default_top_p is not None:
         server._default_top_p = args.default_top_p
 
+    # Configure audio extraction from video
+    server._extract_audio_from_video = args.extract_audio_from_video
+
     # Configure reasoning parser
     if args.reasoning_parser:
         try:
@@ -943,6 +946,12 @@ def create_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Force load model as multimodal (vision) even if name doesn't match auto-detection patterns",
     )
+    serve_parser.add_argument(
+        "--extract-audio-from-video",
+        action="store_true",
+        help="Extract audio track from video inputs and pass to the model. "
+        "Requires ffmpeg. Only useful for models with audio support (e.g. Gemma 4).",
+    )
 
     # Generation defaults
     serve_parser.add_argument(
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index ec761c3f2..5b70440bf 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -802,6 +802,43 @@ def _prepare_video(
         )
         return save_frames_to_temp(frames)
 
+    @staticmethod
+    def _extract_audio_from_video(video_path: str) -> str | None:
+        """Extract the audio track of a video file into a temporary WAV file.
+
+        Returns the path to the extracted 16 kHz mono WAV, or None if the
+        video has no audio track, ffmpeg is unavailable, or extraction fails.
+        The caller is responsible for deleting the returned file.
+        """
+        import os
+        import subprocess
+        import tempfile
+
+        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+        tmp.close()
+        try:
+            result = subprocess.run(
+                [
+                    "ffmpeg", "-y", "-i", video_path,
+                    "-vn", "-ar", "16000", "-ac", "1", "-f", "wav",
+                    tmp.name,
+                ],
+                capture_output=True,
+                timeout=30,
+            )
+            # A WAV with actual samples must be larger than its 44-byte header.
+            if result.returncode == 0 and os.path.getsize(tmp.name) > 44:
+                return tmp.name
+        except Exception as e:
+            logger.debug(f"Could not extract audio from video: {e}")
+        # Clean up the temp file on every failure path (ffmpeg error,
+        # empty audio track, timeout, missing binary).
+        try:
+            os.unlink(tmp.name)
+        except OSError:
+            pass
+        return None
+
     def _collect_video_inputs(self, messages: list[dict]) -> dict[int, list]:
         """Collect video inputs from messages, keyed by message index.
 
@@ -1319,9 +1356,11 @@ def chat(
         from mlx_vlm import generate
         from mlx_vlm.prompt_utils import get_chat_template
 
-        # Extract text and images from messages
-        # Build chat_messages for multi-turn support WITH proper image tokens per message
+        # Extract text, images and audio from messages
+        # Build chat_messages for multi-turn support WITH proper image/audio tokens per message
         all_image_urls = []  # Raw URLs/paths to process later
+        all_audio_urls = []  # Raw audio URLs/paths to process later
+        audio_tokens_counted = 0  # Audio entries already attributed to a message
         chat_messages = []  # List of properly formatted messages for chat template
 
         logger.info(f"MLLM.chat() called with {len(messages)} messages")
@@ -1329,6 +1368,7 @@
         # Pop params early so they don't leak into mlx_vlm.generate()
         video_fps = kwargs.pop("video_fps", DEFAULT_FPS)
         video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
+        extract_audio_from_video = kwargs.pop("extract_audio_from_video", False)
         tools = kwargs.pop("tools", None)
         use_cache = kwargs.pop("use_cache", True)
         enable_thinking = kwargs.pop("enable_thinking", True)
@@ -1360,6 +1400,16 @@
                     all_video_frames.extend(frames)
                     total_frames += len(frames)
                     logger.info(f"Added {len(frames)} frames from video: {vid_input}")
+
+                # Extract audio track from video if explicitly requested
+                if extract_audio_from_video:
+                    video_path = vid_input if isinstance(vid_input, str) else vid_input.get("url", vid_input.get("video_url", {}).get("url", ""))
+                    if video_path and not video_path.startswith("data:"):
+                        audio_path = self._extract_audio_from_video(video_path)
+                        if audio_path:
+                            all_audio_urls.append(audio_path)
+                            logger.info(f"Extracted audio from video: {video_path}")
+
             _msg_video_frame_counts[msg_idx] = total_frames
 
         # Second pass: build chat messages with image counts that include video frames
@@ -1405,17 +1455,30 @@
                             )
                             msg_image_count += 1
 
+                        elif item_type == "audio_url":
+                            aud_url = item.get("audio_url", {})
+                            if isinstance(aud_url, str):
+                                all_audio_urls.append(aud_url)
+                            else:
+                                all_audio_urls.append(aud_url.get("url", ""))
+
                 # Add video frame count to image count for this message
                 msg_image_count += _msg_video_frame_counts.get(msg_idx, 0)
+                # Count only audio collected since the previous message so a
+                # later message does not re-emit tokens for earlier audio.
+                msg_audio_count = len(all_audio_urls) - audio_tokens_counted
+                audio_tokens_counted = len(all_audio_urls)
 
-                # Build properly structured message for Qwen3-VL-MoE
-                # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]}
-                if msg_text or msg_image_count > 0:
-                    if role == "user" and msg_image_count > 0:
-                        # User message WITH images - build content array with image tokens FIRST
+                # Build properly structured message
+                # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "audio"}, ..., {"type": "text", "text": "..."}]}
+                if msg_text or msg_image_count > 0 or msg_audio_count > 0:
+                    if role == "user" and (msg_image_count > 0 or msg_audio_count > 0):
+                        # User message WITH images/audio - build content array with media tokens FIRST
                         content_list = []
                         for _ in range(msg_image_count):
                             content_list.append({"type": "image"})
+                        for _ in range(msg_audio_count):
+                            content_list.append({"type": "audio"})
                         content_list.append(
                             {"type": "text", "text": msg_text, "content": msg_text}
                         )
@@ -1443,7 +1506,7 @@
 
         # Apply chat template directly - messages are already properly structured
         logger.info(
-            f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images"
+            f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images, {len(all_audio_urls)} audios"
         )
         for i, cm in enumerate(chat_messages):
             content_preview = str(cm.get("content", ""))[:80]
@@ -1586,6 +1649,7 @@
             self.processor,
             formatted_prompt,
             all_images if all_images else None,
+            audio=all_audio_urls if all_audio_urls else None,
             max_tokens=max_tokens,
             temp=temperature,
             verbose=False,
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 1a60ee7da..41054933b 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -126,6 +126,7 @@
 _default_timeout: float = 300.0  # Default request timeout in seconds (5 minutes)
 _default_temperature: float | None = None  # Set via --default-temperature
 _default_top_p: float | None = None  # Set via --default-top-p
+_extract_audio_from_video: bool = False  # Set via --extract-audio-from-video
 _metrics_enabled = False
 
 _FALLBACK_TEMPERATURE = 0.7
@@ -1663,6 +1664,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
             chat_kwargs["video_fps"] = request.video_fps
         if request.video_max_frames:
             chat_kwargs["video_max_frames"] = request.video_max_frames
+        extract_audio = (
+            request.extract_audio_from_video
+            if request.extract_audio_from_video is not None
+            else _extract_audio_from_video
+        )
+        chat_kwargs["extract_audio_from_video"] = extract_audio
 
         # SpecPrefill: per-request overrides
         if request.specprefill is not None: