
Commit a2df5a3

💥👽 Update ErrorResponse and supported vllm versions (#102)
* ⬆️ Bump vllm support
* 👽⬆️ Breaking ErrorResponse changes with vllm 0.10.1
* ✨ Conform to original detectors API
* 🔧 Update Dockerfile vllm version
* 🥅 Handle request validation
* 🥅 Keep current request validation handling for other endpoints
* 🐛 Fix import
* 🐛♻️ Format validation errors

Signed-off-by: Evaline Ju <[email protected]>
1 parent 5a39543 commit a2df5a3
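
The breaking piece is the ErrorResponse shape: with vllm 0.10.1+ the message, type, and code fields move into a nested ErrorInfo object instead of living directly on ErrorResponse. A minimal before/after sketch, assuming only what the diffs below show (the message value is illustrative):

from http import HTTPStatus

from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

# Old shape (vllm < 0.10.1): fields sat directly on ErrorResponse
#   err = ErrorResponse(message="...", type="BadRequestError", code=HTTPStatus.BAD_REQUEST.value)
#   err.code, err.message

# New shape (vllm >= 0.10.1): fields are wrapped in an ErrorInfo
err = ErrorResponse(
    error=ErrorInfo(
        message="streaming is not supported for the detector",  # illustrative value
        type="BadRequestError",
        code=HTTPStatus.BAD_REQUEST.value,
    )
)
print(err.error.code, err.error.message)  # now accessed through .error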

File tree

11 files changed: +185 / -103 lines


Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ ARG BASE_UBI_IMAGE_TAG=9.6
 ARG PYTHON_VERSION=3.12

 ### Build layer
-FROM quay.io/vllm/vllm-cuda:0.10.0.2 as build
+FROM quay.io/vllm/vllm-cuda:0.11.0.1 as build

 ARG PYTHON_VERSION
 ENV PYTHON_VERSION=${PYTHON_VERSION}

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "vllm-detector-adapter"
-version = "0.8.0"
+version = "0.9.0"
 authors = [
     { name = "Gaurav Kumbhat", email = "[email protected]" },
     { name = "Evaline Ju", email = "[email protected]" },
@@ -16,8 +16,8 @@ dependencies = ["orjson>=3.10.16,<3.11"]
 vllm-tgis-adapter = ["vllm-tgis-adapter>=0.8.0,<0.9.0"]
 vllm = [
     # Note: vllm < 0.10.0 has issues with transformers >= 4.54.0
-    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.10.0 ; sys_platform == 'darwin'",
-    "vllm>=0.10.0,<0.10.1 ; sys_platform != 'darwin'",
+    "vllm @ git+https://github.com/vllm-project/vllm.git@v0.11.0 ; sys_platform == 'darwin'",
+    "vllm>=0.10.1,<0.11.1 ; sys_platform != 'darwin'",
 ]

 ## Dev Extra Sets ##
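
Not part of this commit, but a quick way to sanity-check the new pin at runtime is a hypothetical guard like the sketch below; it assumes importlib.metadata and the packaging library are available and covers only the non-darwin range declared above:

# Hypothetical runtime guard (not in this commit): confirm the installed vllm
# falls inside the supported non-darwin window from pyproject.toml.
from importlib.metadata import version

from packaging.version import Version

vllm_version = Version(version("vllm"))
assert Version("0.10.1") <= vllm_version < Version("0.11.1"), (
    f"vllm {vllm_version} is outside the supported range >=0.10.1,<0.11.1"
)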

tests/generative_detectors/test_base.py

Lines changed: 2 additions & 2 deletions
@@ -237,5 +237,5 @@ def test_content_analysis_errorresponse_verification(detection_base):
     result = asyncio.run(base_instance.content_analysis(content_request))

     assert isinstance(result, ErrorResponse)
-    assert result.type == "BadRequestError"
-    assert "does not have content" in result.message
+    assert result.error.type == "BadRequestError"
+    assert "does not have content" in result.error.message

tests/generative_detectors/test_granite_guardian.py

Lines changed: 27 additions & 20 deletions
@@ -414,10 +414,10 @@ def test__make_tools_request_no_tool_calls(granite_guardian_detection):
     )
     processed_request = granite_guardian_detection_instance._make_tools_request(request)
     assert type(processed_request) == ErrorResponse
-    assert processed_request.code == HTTPStatus.BAD_REQUEST
+    assert processed_request.error.code == HTTPStatus.BAD_REQUEST
     assert (
         "no assistant message was provided with tool_calls for analysis"
-        in processed_request.message
+        in processed_request.error.message
     )


@@ -437,9 +437,10 @@ def test__make_tools_request_random_risk(granite_guardian_detection):
     )
     processed_request = granite_guardian_detection_instance._make_tools_request(request)
     assert type(processed_request) == ErrorResponse
-    assert processed_request.code == HTTPStatus.BAD_REQUEST
+    assert processed_request.error.code == HTTPStatus.BAD_REQUEST
     assert (
-        "tools analysis is not supported with given risk" in processed_request.message
+        "tools analysis is not supported with given risk"
+        in processed_request.error.message
     )


@@ -773,8 +774,10 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
         )
     )
     assert type(chat_request) == ErrorResponse
-    assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name or criteria_id for context analysis" in chat_request.message
+    assert chat_request.error.code == HTTPStatus.BAD_REQUEST
+    assert (
+        "No risk_name or criteria_id for context analysis" in chat_request.error.message
+    )


 def test_request_to_chat_completion_request_empty_guardian_config(
@@ -793,8 +796,10 @@ def test_request_to_chat_completion_request_empty_guardian_config(
         )
     )
     assert type(chat_request) == ErrorResponse
-    assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name or criteria_id for context analysis" in chat_request.message
+    assert chat_request.error.code == HTTPStatus.BAD_REQUEST
+    assert (
+        "No risk_name or criteria_id for context analysis" in chat_request.error.message
+    )


 def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id(
@@ -816,8 +821,10 @@ def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id(
         )
     )
     assert type(chat_request) == ErrorResponse
-    assert chat_request.code == HTTPStatus.BAD_REQUEST
-    assert "No risk_name or criteria_id for context analysis" in chat_request.message
+    assert chat_request.error.code == HTTPStatus.BAD_REQUEST
+    assert (
+        "No risk_name or criteria_id for context analysis" in chat_request.error.message
+    )


 def test_request_to_chat_completion_request_unsupported_risk_name(
@@ -839,10 +846,10 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
         )
     )
     assert type(chat_request) == ErrorResponse
-    assert chat_request.code == HTTPStatus.BAD_REQUEST
+    assert chat_request.error.code == HTTPStatus.BAD_REQUEST
     assert (
         "risk_name or criteria_id foo is not compatible with context analysis"
-        in chat_request.message
+        in chat_request.error.message
     )


@@ -1085,10 +1092,10 @@ def test_context_analyze_unsupported_risk(
         granite_guardian_detection_instance.context_analyze(context_request)
     )
     assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST
     assert (
         "risk_name or criteria_id boo is not compatible with context analysis"
-        in detection_response.message
+        in detection_response.error.message
     )


@@ -1395,8 +1402,8 @@ def test_chat_detection_errors_on_stream(granite_guardian_detection):
         granite_guardian_detection_instance.chat(chat_request)
     )
     assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
-    assert "streaming is not supported" in detection_response.message
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value
+    assert "streaming is not supported" in detection_response.error.message


 def test_chat_detection_errors_on_jinja_template_error(granite_guardian_detection):
@@ -1414,8 +1421,8 @@ def test_chat_detection_errors_on_jinja_template_error(granite_guardian_detectio
         granite_guardian_detection_instance.chat(chat_request)
     )
     assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
-    assert "Template error" in detection_response.message
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value
+    assert "Template error" in detection_response.error.message


 def test_chat_detection_errors_on_undefined_jinja_error(granite_guardian_detection):
@@ -1433,8 +1440,8 @@ def test_chat_detection_errors_on_undefined_jinja_error(granite_guardian_detecti
         granite_guardian_detection_instance.chat(chat_request)
    )
     assert type(detection_response) == ErrorResponse
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
-    assert "Template error" in detection_response.message
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value
+    assert "Template error" in detection_response.error.message


 def test_risk_bank_extraction(granite_guardian_detection):

tests/generative_detectors/test_llama_guard.py

Lines changed: 2 additions & 2 deletions
@@ -375,7 +375,7 @@ def test_chat_detection_with_tools(llama_guard_detection):
     )
     response = asyncio.run(llama_guard_detection_instance.chat(chat_request))
     assert type(response) == ErrorResponse
-    assert response.code == HTTPStatus.NOT_IMPLEMENTED
+    assert response.error.code == HTTPStatus.NOT_IMPLEMENTED


 def test_context_analyze(llama_guard_detection):
@@ -392,7 +392,7 @@ def test_context_analyze(llama_guard_detection):
         llama_guard_detection_instance.context_analyze(context_request)
     )
     assert type(response) == ErrorResponse
-    assert response.code == HTTPStatus.NOT_IMPLEMENTED
+    assert response.error.code == HTTPStatus.NOT_IMPLEMENTED


 def test_generation_analyze(llama_guard_detection, llama_guard_completion_response):

tests/test_protocol.py

Lines changed: 6 additions & 6 deletions
@@ -251,9 +251,9 @@ def test_response_from_single_content_detection_missing_content():
     assert type(detection_response) == ErrorResponse
     assert (
         "Choice 0 from chat completion does not have content"
-        in detection_response.message
+        in detection_response.error.message
     )
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value


 #### General detection response tests
@@ -355,9 +355,9 @@ def test_response_from_completion_response_missing_content():
     assert type(detection_response) == ErrorResponse
     assert (
         "Choice 1 from chat completion does not have content"
-        in detection_response.message
+        in detection_response.error.message
     )
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value


 def test_response_from_empty_string_content_detection():
@@ -388,6 +388,6 @@ def test_response_from_empty_string_content_detection():
     assert type(detection_response) == ErrorResponse
     assert (
         "Choice 0 from chat completion does not have content"
-        in detection_response.message
+        in detection_response.error.message
     )
-    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+    assert detection_response.error.code == HTTPStatus.BAD_REQUEST.value

vllm_detector_adapter/api_server.py

Lines changed: 43 additions & 9 deletions
@@ -1,10 +1,12 @@
 # Standard
 from argparse import Namespace
+from http import HTTPStatus
 import inspect
 import signal

 # Third Party
 from fastapi import Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from starlette.datastructures import State
 from vllm.config import ModelConfig
@@ -14,7 +16,7 @@
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai import api_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
-from vllm.entrypoints.openai.protocol import ErrorResponse
+from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.utils import FlexibleArgumentParser, is_valid_ipv6_address, set_ulimit
@@ -41,6 +43,7 @@
 # Third Party
 from vllm.reasoning import ReasoningParserManager

+
 TIMEOUT_KEEP_ALIVE = 5  # seconds

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -162,6 +165,37 @@ def signal_handler(*_) -> None:
     # Use vllm build_app which adds middleware
     app = api_server.build_app(args)

+    # Override exception handler to flatten errors for detectors API
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(
+        request: Request, exc: RequestValidationError
+    ):
+        exc_str = str(exc)
+        errors_str = str(exc.errors())
+        message = None
+        if exc.errors() and errors_str and errors_str != exc_str:
+            message = f"{exc_str} {errors_str}"
+        else:
+            message = exc_str
+
+        error_info = ErrorInfo(
+            message=message,
+            type=HTTPStatus.BAD_REQUEST.phrase,
+            code=HTTPStatus.BAD_REQUEST,
+        )
+
+        if request.url.path.startswith("/api/v1/text"):
+            # Flatten detectors API request validation errors
+            return JSONResponse(
+                content=error_info.model_dump(), status_code=HTTPStatus.BAD_REQUEST
+            )
+        else:
+            # vLLM general request validation error handling
+            err = ErrorResponse(error=error_info)
+            return JSONResponse(
+                content=err.model_dump(), status_code=HTTPStatus.BAD_REQUEST
+            )
+
     # api_server.init_app_state takes vllm_config
     # ref. https://github.com/vllm-project/vllm/pull/16572
     if hasattr(engine_client, "get_vllm_config"):
@@ -213,9 +247,9 @@ async def create_chat_detection(request: ChatDetectionRequest, raw_request: Requ
     detector_response = await chat_detection(raw_request).chat(request, raw_request)

     if isinstance(detector_response, ErrorResponse):
-        # ErrorResponse includes code and message, corresponding to errors for the detectorAPI
         return JSONResponse(
-            content=detector_response.model_dump(), status_code=detector_response.code
+            content=detector_response.error.model_dump(),
+            status_code=detector_response.error.code,
         )

     elif isinstance(detector_response, DetectionResponse):
@@ -235,9 +269,9 @@ async def create_context_doc_detection(
     )

     if isinstance(detector_response, ErrorResponse):
-        # ErrorResponse includes code and message, corresponding to errors for the detectorAPI
         return JSONResponse(
-            content=detector_response.model_dump(), status_code=detector_response.code
+            content=detector_response.error.model_dump(),
+            status_code=detector_response.error.code,
        )

     elif isinstance(detector_response, DetectionResponse):
@@ -256,9 +290,9 @@ async def create_contents_detection(
         request, raw_request
     )
     if isinstance(detector_response, ErrorResponse):
-        # ErrorResponse includes code and message, corresponding to errors for the detectorAPI
         return JSONResponse(
-            content=detector_response.model_dump(), status_code=detector_response.code
+            content=detector_response.error.model_dump(),
+            status_code=detector_response.error.code,
        )

     elif isinstance(detector_response, ContentsDetectionResponse):
@@ -277,9 +311,9 @@ async def create_generation_detection(
         request, raw_request
     )
     if isinstance(detector_response, ErrorResponse):
-        # ErrorResponse includes code and message, corresponding to errors for the detectorAPI
         return JSONResponse(
-            content=detector_response.model_dump(), status_code=detector_response.code
+            content=detector_response.error.model_dump(),
+            status_code=detector_response.error.code,
        )

     elif isinstance(detector_response, DetectionResponse):
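
For reference, a sketch of the two response bodies the new RequestValidationError handler produces, assuming ErrorInfo and ErrorResponse behave as the pydantic models used above (the message value is illustrative):

from http import HTTPStatus

from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

error_info = ErrorInfo(
    message="field required",  # illustrative validation message
    type=HTTPStatus.BAD_REQUEST.phrase,
    code=HTTPStatus.BAD_REQUEST,
)

# Detector endpoints under /api/v1/text get the flattened body:
print(error_info.model_dump())  # roughly {"message": ..., "type": "Bad Request", "code": 400}

# Every other route keeps vLLM's nested body:
print(ErrorResponse(error=error_info).model_dump())  # roughly {"error": {"message": ..., ...}}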

vllm_detector_adapter/generative_detectors/base.py

Lines changed: 21 additions & 12 deletions
@@ -12,6 +12,7 @@
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ErrorInfo,
     ErrorResponse,
 )
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -156,9 +157,11 @@ def preprocess_request( # noqa: F811
         # Tools detection is not generalized
         if request.tools:
             return ErrorResponse(
-                message="tools are not supported for the detector",
-                type="NotImplementedError",
-                code=HTTPStatus.NOT_IMPLEMENTED.value,
+                error=ErrorInfo(
+                    message="tools are not supported for the detector",
+                    type="NotImplementedError",
+                    code=HTTPStatus.NOT_IMPLEMENTED.value,
+                )
             )
         return request

@@ -241,9 +244,11 @@ async def process_chat_completion_with_scores(
        # object would look different, and content would have to be aggregated.
        if chat_completion_request.stream:
            return ErrorResponse(
-                message="streaming is not supported for the detector",
-                type="BadRequestError",
-                code=HTTPStatus.BAD_REQUEST.value,
+                error=ErrorInfo(
+                    message="streaming is not supported for the detector",
+                    type="BadRequestError",
+                    code=HTTPStatus.BAD_REQUEST.value,
+                )
            )

        # Manually set logprobs to True to calculate score later on
@@ -271,9 +276,11 @@ async def process_chat_completion_with_scores(
            # Users _may_ be able to correct some of these errors by changing the input
            # but the error message may not be directly user-comprehensible
            chat_response = ErrorResponse(
-                message=e.message or "Template error",
-                type="BadRequestError",
-                code=HTTPStatus.BAD_REQUEST.value,
+                error=ErrorInfo(
+                    message=e.message or "Template error",
+                    type="BadRequestError",
+                    code=HTTPStatus.BAD_REQUEST.value,
+                )
            )

        logger.debug("Raw chat completion response: %s", chat_response)
@@ -376,9 +383,11 @@ async def context_analyze(
        # Return "not implemented" here since context analysis may not
        # generally apply to all models at this time
        return ErrorResponse(
-            message="context analysis is not supported for the detector",
-            type="NotImplementedError",
-            code=HTTPStatus.NOT_IMPLEMENTED.value,
+            error=ErrorInfo(
+                message="context analysis is not supported for the detector",
+                type="NotImplementedError",
+                code=HTTPStatus.NOT_IMPLEMENTED.value,
+            )
        )

    async def content_analysis(
