diff --git a/docs/guides/multimodal.md b/docs/guides/multimodal.md index b916079d..6d1177c1 100644 --- a/docs/guides/multimodal.md +++ b/docs/guides/multimodal.md @@ -162,6 +162,18 @@ curl http://localhost:8000/v1/chat/completions \ | Local file | `{"type": "video", "video": "/path/to/video.mp4"}` | | Base64 | `{"type": "video_url", "video_url": {"url": "data:video/mp4;base64,..."}}` | +### Remote URL safety + +Remote image, video, and audio URLs are checked before each fetch and redirect +hop. URLs that resolve to localhost, link-local, private, or otherwise +non-global addresses are rejected with a generic client error while detailed +diagnostics stay in server logs. + +This validation does not pin the IP address used by the later HTTP transport +connection. In environments where DNS rebinding or split-horizon DNS is in +scope, run vllm-mlx behind network egress controls or fetch media through a +trusted proxy that enforces the destination policy at connect time. + ## Python API ```python diff --git a/tests/test_mllm.py b/tests/test_mllm.py index ff22246e..443dadbd 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -170,6 +170,17 @@ def test_validate_url_safety_allows_public_ip(self): _validate_url_safety("https://8.8.8.8/image.jpg") + def test_unsafe_remote_url_error_has_safe_public_message(self): + """Public safety errors should not disclose resolved hosts or IPs.""" + from vllm_mlx.models.mllm import UnsafeRemoteURLError, _validate_url_safety + + with pytest.raises(UnsafeRemoteURLError) as exc_info: + _validate_url_safety("http://169.254.169.254/latest/meta-data/") + + assert "169.254.169.254" in str(exc_info.value) + assert exc_info.value.public_message == "Remote media URL is not allowed" + assert "169.254.169.254" not in exc_info.value.public_message + def test_request_with_safe_redirects_blocks_unsafe_redirect(self, monkeypatch): """Test that redirect hops are validated before a second request.""" from vllm_mlx.models import mllm diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py index 78f8c4c2..1c4d9d98 100644 --- a/tests/test_responses_api.py +++ b/tests/test_responses_api.py @@ -152,6 +152,50 @@ def test_responses_applies_server_default_chat_template_kwargs(self, client): "enable_thinking": False } + def test_streaming_responses_validates_remote_media_once(self, client, monkeypatch): + import vllm_mlx.server as srv + + engine = _mock_engine() + engine._stream_outputs = [_stream_output("Hello there", finish_reason="stop")] + srv._engine = engine + + validate_calls = [] + + def fake_validate(messages): + validate_calls.append(messages) + + def fake_responses_to_chat(_request): + return srv.ChatCompletionRequest( + model="test-model", + messages=[ + srv.Message( + role="user", + content=[ + {"type": "text", "text": "describe this"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.png"}, + }, + ], + ) + ], + max_tokens=8, + stream=True, + ) + + monkeypatch.setattr(srv, "_validate_remote_media_urls", fake_validate) + monkeypatch.setattr( + srv, "_responses_request_to_chat_request", fake_responses_to_chat + ) + + resp = client.post( + "/v1/responses", + json={"model": "test-model", "input": "describe this", "stream": True}, + ) + + assert resp.status_code == 200 + assert len(validate_calls) == 1 + def test_responses_request_kwargs_override_server_defaults(self, client): import vllm_mlx.server as srv diff --git a/tests/test_server.py b/tests/test_server.py index 80d42da6..9e8af536 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2561,6 +2561,130 @@ def fake_extract(messages, preserve_native_format=False): assert response.status_code == 200 assert extract_calls["count"] == 1 + def test_chat_completion_sanitizes_remote_media_safety_errors( + self, client, monkeypatch + ): + """Unsafe remote media URLs should return a generic public error.""" + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-mllm" + is_mllm = True + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): # pragma: no cover + raise AssertionError("unsafe URL should fail during preparation") + + async def fake_acquire(_raw_request, **_kwargs): + return FakeEngine() + + async def fake_release(*_args, **_kwargs): + return None + + monkeypatch.setattr(server, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(server, "_release_engine_for_request", fake_release) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this"}, + { + "type": "image_url", + "image_url": { + "url": "http://169.254.169.254/latest/meta-data/" + }, + }, + ], + } + ], + "max_tokens": 8, + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Remote media URL is not allowed" + assert "169.254.169.254" not in response.text + + def test_anthropic_message_sanitizes_remote_media_safety_errors( + self, client, monkeypatch + ): + """Anthropic preparation should sanitize URL-safety failures too.""" + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-mllm" + is_mllm = True + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): # pragma: no cover + raise AssertionError("unsafe URL should fail during preparation") + + async def fake_acquire(_raw_request, **_kwargs): + return FakeEngine() + + async def fake_release(*_args, **_kwargs): + return None + + def fake_anthropic_to_openai(_anthropic_request): + return server.ChatCompletionRequest( + model="test-model", + messages=[ + server.Message( + role="user", + content=[ + {"type": "text", "text": "describe this"}, + { + "type": "image_url", + "image_url": { + "url": "http://169.254.169.254/latest/meta-data/" + }, + }, + ], + ) + ], + max_tokens=8, + ) + + monkeypatch.setattr(server, "_acquire_default_engine_for_request", fake_acquire) + monkeypatch.setattr(server, "_release_engine_for_request", fake_release) + monkeypatch.setattr(server, "anthropic_to_openai", fake_anthropic_to_openai) + monkeypatch.setattr(server, "_model_name", "test-model") + monkeypatch.setattr(server, "_default_timeout", 30.0) + monkeypatch.setattr(server, "_default_max_tokens", 128) + monkeypatch.setattr(server, "_api_key", None) + monkeypatch.setattr( + server, + "_rate_limiter", + server.RateLimiter(requests_per_minute=60, enabled=False), + ) + + response = client.post( + "/v1/messages", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "describe this"}], + "max_tokens": 8, + }, + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Remote media URL is not allowed" + assert "169.254.169.254" not in response.text + class TestChatCompletionStreamingModeSwitching: """Endpoint-level regression tests for stream/non-stream mode switching.""" diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 8477e644..44839791 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -122,7 +122,14 @@ class FileSizeExceededError(Exception): class UnsafeRemoteURLError(ValueError): """Raised when a remote media URL targets an unsafe destination.""" - pass + def __init__( + self, + message: str, + *, + public_message: str = "Remote media URL is not allowed", + ) -> None: + super().__init__(message) + self.public_message = public_message @dataclass diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 75590d9b..201b960b 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -167,6 +167,7 @@ load_registry_config, ) from .metrics import metrics as _metrics +from .models.mllm import UnsafeRemoteURLError, _validate_url_safety, is_url from .reasoning import get_parser as get_reasoning_parser from .tool_parsers import ToolParserManager, get_parser_stop_tokens @@ -264,6 +265,8 @@ def _prepare_chat_messages( request_messages: list[Message | dict], ) -> tuple[list[dict], list, list, list, bool]: """Normalize messages and collect media once for both stream/non-stream paths.""" + _validate_remote_media_urls(request_messages) + is_mllm = bool(getattr(engine, "is_mllm", False)) preserve_native = bool(getattr(engine, "preserve_native_tool_format", False)) @@ -332,6 +335,52 @@ def _prepare_chat_messages( return messages, images, videos, audios, has_media +def _iter_remote_media_urls(messages: list[Message | dict]): + """Yield remote media URLs from OpenAI-style multimodal message content.""" + for msg in messages: + content = msg.get("content") if isinstance(msg, dict) else msg.content + if not isinstance(content, list): + continue + for item in content: + if hasattr(item, "model_dump"): + item = item.model_dump(exclude_none=True) + elif hasattr(item, "dict"): + item = {k: v for k, v in item.dict().items() if v is not None} + if not isinstance(item, dict): + continue + + item_type = item.get("type", "") + media_value = None + if item_type == "image_url": + media_value = item.get("image_url", {}) + elif item_type == "video_url": + media_value = item.get("video_url", {}) + elif item_type == "audio_url": + media_value = item.get("audio_url", {}) + elif item_type in {"image", "video", "audio"}: + media_value = item.get(item_type, item.get("url", "")) + + if isinstance(media_value, dict): + media_value = media_value.get("url", "") + if isinstance(media_value, str) and is_url(media_value): + yield media_value + + +def _validate_remote_media_urls(messages: list[Message | dict]) -> None: + """Validate remote media URLs during request preparation.""" + for url in _iter_remote_media_urls(messages): + _validate_url_safety(url) + + +def _raise_remote_media_http_error(exc: UnsafeRemoteURLError) -> None: + """Log internal URL-safety detail while returning a generic client error.""" + logger.warning( + "Blocked unsafe remote media URL: %s", + _sanitize_log_text(exc, limit=500), + ) + raise HTTPException(status_code=400, detail=exc.public_message) from exc + + def _prepare_json_logits_processor( engine: BaseEngine, messages: list[dict], @@ -1980,6 +2029,8 @@ def _build_response_object( def _prepare_responses_request( request: ResponsesRequest, + *, + validate_remote_media: bool = True, ) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict]: """Prepare a Responses request for execution on the chat engine.""" _validate_model_name(request.model) @@ -1994,6 +2045,9 @@ def _prepare_responses_request( f"tools={len(request.tools)}" ) + if validate_remote_media: + _validate_remote_media_urls(chat_request.messages) + messages, images, videos, audios = extract_multimodal_content( chat_request.messages, preserve_native_format=engine.preserve_native_tool_format, @@ -2019,6 +2073,13 @@ def _prepare_responses_request( return engine, chat_request, messages, chat_kwargs +def _prepare_streaming_responses_request( + request: ResponsesRequest, +) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict]: + """Prepare a streaming Responses request after eager URL validation.""" + return _prepare_responses_request(request, validate_remote_media=False) + + async def _run_responses_request( request: ResponsesRequest, raw_request: Request, @@ -2077,7 +2138,9 @@ async def _run_responses_request( async def _stream_responses_request(request: ResponsesRequest) -> AsyncIterator[str]: """Execute a Responses API request and stream SSE events incrementally.""" - engine, chat_request, messages, chat_kwargs = _prepare_responses_request(request) + engine, chat_request, messages, chat_kwargs = _prepare_streaming_responses_request( + request + ) response_id = _new_response_item_id("resp") sequence = 1 @@ -4448,11 +4511,15 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re release_on_exit = True try: - prepared = _prepare_chat_completion_invocation( - engine, - request, - effective_max_tokens, - ) + try: + prepared = _prepare_chat_completion_invocation( + engine, + request, + effective_max_tokens, + ) + except UnsafeRemoteURLError as exc: + tracker.finish(result="client_error") + _raise_remote_media_http_error(exc) if request.stream: response = StreamingResponse( @@ -4634,15 +4701,21 @@ def _get_engine_tokenizer(engine) -> object | None: ) async def create_response(request: ResponsesRequest, raw_request: Request): """Create a Responses API response.""" - if request.stream: - return StreamingResponse( - _disconnect_guard(_stream_responses_request(request), raw_request), - media_type="text/event-stream", + try: + if request.stream: + chat_request = _responses_request_to_chat_request(request) + _validate_remote_media_urls(chat_request.messages) + return StreamingResponse( + _disconnect_guard(_stream_responses_request(request), raw_request), + media_type="text/event-stream", + ) + + response_object, _persisted_messages = await _run_responses_request( + request, raw_request ) + except UnsafeRemoteURLError as exc: + _raise_remote_media_http_error(exc) - response_object, _persisted_messages = await _run_responses_request( - request, raw_request - ) if response_object is None: return Response(status_code=499) @@ -4770,6 +4843,22 @@ def _convert_anthropic_stop_reason(openai_reason: str | None) -> str: return mapping.get(openai_reason or "", "end_turn") +def _prepare_anthropic_endpoint_invocation( + engine: BaseEngine, + openai_request: ChatCompletionRequest, + effective_max_tokens: int, +) -> PreparedChatInvocation: + """Prepare Anthropic invocation and convert URL-safety errors to 400s.""" + try: + return _prepare_anthropic_invocation( + engine, + openai_request, + effective_max_tokens, + ) + except UnsafeRemoteURLError as exc: + _raise_remote_media_http_error(exc) + + @app.post( "/v1/messages", dependencies=[Depends(verify_api_key), Depends(check_rate_limit)] ) @@ -4838,7 +4927,7 @@ async def create_anthropic_message( if engine is None: return Response(status_code=499) release_on_exit = True - prepared = _prepare_anthropic_invocation( + prepared = _prepare_anthropic_endpoint_invocation( engine, openai_request, effective_max_tokens,