Skip to content

Commit 55a5dd0

Browse files
authored
✨ Support Granite Guardian 3.3 criteria parameters (#87)
* ✨✅ Add GG 3.3+ criteria parameters Signed-off-by: Evaline Ju <[email protected]> * ✅ Add criteria parameters tests Signed-off-by: Evaline Ju <[email protected]> * 📌 Pin transformers upper bound Signed-off-by: Evaline Ju <[email protected]> * ✅ Update test name with custom Signed-off-by: Evaline Ju <[email protected]> --------- Signed-off-by: Evaline Ju <[email protected]>
1 parent 6e191cb commit 55a5dd0

File tree

3 files changed

+161
-18
lines changed

3 files changed

+161
-18
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ dependencies = ["orjson>=3.10.16,<3.11"]
1515
[project.optional-dependencies]
1616
vllm-tgis-adapter = ["vllm-tgis-adapter>=0.7.0,<0.7.2"]
1717
vllm = [
18-
# Note: 0.8.4 has a triton bug on Mac
18+
"transformers<4.54.0", # vllm <= 0.10.0 has issues with higher transformers versions, fixed later in https://github.com/vllm-project/vllm/pull/20921
1919
"vllm @ git+https://github.com/vllm-project/[email protected] ; sys_platform == 'darwin'",
2020
"vllm>=0.7.2,<0.9.1 ; sys_platform != 'darwin'",
2121
]

tests/generative_detectors/test_granite_guardian.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,44 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
423423
}
424424

425425

426+
def test_preprocess_chat_request_with_custom_criteria_detector_params(
427+
granite_guardian_detection,
428+
):
429+
# Guardian 3.3+ parameters
430+
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
431+
detector_params = {
432+
"custom_criteria": "Here is some custom criteria",
433+
"custom_scoring_schema": "If text meets criteria say yes",
434+
"foo": "bar",
435+
}
436+
initial_request = ChatDetectionRequest(
437+
messages=[
438+
DetectionChatMessageParam(
439+
role="user", content="How do I figure out how to break into a house?"
440+
)
441+
],
442+
detector_params=detector_params,
443+
)
444+
processed_request = granite_guardian_detection_instance.preprocess_request(
445+
initial_request, fn_type=DetectorType.TEXT_CHAT
446+
)
447+
assert type(processed_request) == ChatDetectionRequest
448+
# Processed request should not have these extra params
449+
assert "custom_criteria" not in processed_request.detector_params
450+
assert "custom_scoring_schema" not in processed_request.detector_params
451+
assert "chat_template_kwargs" in processed_request.detector_params
452+
assert (
453+
"guardian_config" in processed_request.detector_params["chat_template_kwargs"]
454+
)
455+
guardian_config = processed_request.detector_params["chat_template_kwargs"][
456+
"guardian_config"
457+
]
458+
assert guardian_config == {
459+
"custom_criteria": "Here is some custom criteria",
460+
"custom_scoring_schema": "If text meets criteria say yes",
461+
}
462+
463+
426464
def test_preprocess_chat_request_with_extra_chat_template_kwargs(
427465
granite_guardian_detection,
428466
):
@@ -534,6 +572,41 @@ def test_request_to_chat_completion_request_response_analysis(
534572
)
535573

536574

575+
def test_request_to_chat_completion_request_response_analysis_criteria_id(
576+
granite_guardian_detection,
577+
):
578+
# Guardian 3.3 parameters
579+
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
580+
context_request = ContextAnalysisRequest(
581+
content=CONTENT,
582+
context_type="docs",
583+
context=[CONTEXT_DOC],
584+
detector_params={
585+
"n": 3,
586+
"chat_template_kwargs": {
587+
"guardian_config": {"criteria_id": "groundedness"}
588+
},
589+
},
590+
)
591+
chat_request = (
592+
granite_guardian_detection_instance._request_to_chat_completion_request(
593+
context_request, MODEL_NAME, fn_type=DetectorType.TEXT_CONTEXT_DOC
594+
)
595+
)
596+
assert type(chat_request) == ChatCompletionRequest
597+
assert chat_request.messages[0]["role"] == "context"
598+
assert chat_request.messages[0]["content"] == CONTEXT_DOC
599+
assert chat_request.messages[1]["role"] == "assistant"
600+
assert chat_request.messages[1]["content"] == CONTENT
601+
assert chat_request.model == MODEL_NAME
602+
# detector_params
603+
assert chat_request.n == 3
604+
assert (
605+
chat_request.chat_template_kwargs["guardian_config"]["criteria_id"]
606+
== "groundedness"
607+
)
608+
609+
537610
def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detection):
538611
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
539612
context_request = ContextAnalysisRequest(
@@ -549,7 +622,7 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
549622
)
550623
assert type(chat_request) == ErrorResponse
551624
assert chat_request.code == HTTPStatus.BAD_REQUEST
552-
assert "No risk_name for context analysis" in chat_request.message
625+
assert "No risk_name or criteria_id for context analysis" in chat_request.message
553626

554627

555628
def test_request_to_chat_completion_request_empty_guardian_config(
@@ -569,10 +642,10 @@ def test_request_to_chat_completion_request_empty_guardian_config(
569642
)
570643
assert type(chat_request) == ErrorResponse
571644
assert chat_request.code == HTTPStatus.BAD_REQUEST
572-
assert "No risk_name for context analysis" in chat_request.message
645+
assert "No risk_name or criteria_id for context analysis" in chat_request.message
573646

574647

575-
def test_request_to_chat_completion_request_missing_risk_name(
648+
def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id(
576649
granite_guardian_detection,
577650
):
578651
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
@@ -592,7 +665,7 @@ def test_request_to_chat_completion_request_missing_risk_name(
592665
)
593666
assert type(chat_request) == ErrorResponse
594667
assert chat_request.code == HTTPStatus.BAD_REQUEST
595-
assert "No risk_name for context analysis" in chat_request.message
668+
assert "No risk_name or criteria_id for context analysis" in chat_request.message
596669

597670

598671
def test_request_to_chat_completion_request_unsupported_risk_name(
@@ -616,7 +689,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
616689
assert type(chat_request) == ErrorResponse
617690
assert chat_request.code == HTTPStatus.BAD_REQUEST
618691
assert (
619-
"risk_name foo is not compatible with context analysis" in chat_request.message
692+
"risk_name or criteria_id foo is not compatible with context analysis"
693+
in chat_request.message
620694
)
621695

622696

@@ -816,7 +890,7 @@ def test_context_analyze_unsupported_risk(
816890
assert type(detection_response) == ErrorResponse
817891
assert detection_response.code == HTTPStatus.BAD_REQUEST
818892
assert (
819-
"risk_name boo is not compatible with context analysis"
893+
"risk_name or criteria_id boo is not compatible with context analysis"
820894
in detection_response.message
821895
)
822896

@@ -970,6 +1044,34 @@ def test_chat_detection_with_tools(
9701044
assert len(detections) == 2 # 2 choices
9711045

9721046

1047+
def test_chat_detection_with_tools_criteria_id(
1048+
granite_guardian_detection, granite_guardian_completion_response
1049+
):
1050+
# Guardian 3.3 parameters
1051+
granite_guardian_detection_instance = asyncio.run(granite_guardian_detection)
1052+
chat_request = ChatDetectionRequest(
1053+
messages=[
1054+
DetectionChatMessageParam(
1055+
role="user",
1056+
content=USER_CONTENT_TOOLS,
1057+
),
1058+
DetectionChatMessageParam(role="assistant", tool_calls=[TOOL_CALL]),
1059+
],
1060+
tools=[TOOL],
1061+
detector_params={"criteria_id": "function_call", "n": 2},
1062+
)
1063+
with patch(
1064+
"vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion",
1065+
return_value=granite_guardian_completion_response,
1066+
):
1067+
detection_response = asyncio.run(
1068+
granite_guardian_detection_instance.chat(chat_request)
1069+
)
1070+
assert type(detection_response) == DetectionResponse
1071+
detections = detection_response.model_dump()
1072+
assert len(detections) == 2 # 2 choices
1073+
1074+
9731075
def test_chat_detection_with_tools_wrong_risk(
9741076
granite_guardian_detection, granite_guardian_completion_response
9751077
):

vllm_detector_adapter/generative_detectors/granite_guardian.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ class GraniteGuardianToolCallFunctionObject(TypedDict):
4040

4141

4242
class GraniteGuardian(ChatCompletionDetectionBase):
43-
43+
# Note: Earlier generations of Granite Guardian use 'risk' while Granite Guardian
44+
# 3.3 refers to 'criteria' for generalization. For now, because the taxonomy is
45+
# still characterized as a 'risk' taxonomy, the detection type remains.
4446
DETECTION_TYPE = "risk"
4547
# User text pattern in task template
4648
USER_TEXT_PATTERN = "user_text"
@@ -62,6 +64,7 @@ class GraniteGuardian(ChatCompletionDetectionBase):
6264
INDENT = orjson.OPT_INDENT_2
6365

6466
# Risk Bank name defined in the chat template
67+
# Not actively used; see https://github.com/foundation-model-stack/vllm-detector-adapter/issues/64
6568
RISK_BANK_VAR_NAME = "risk_bank"
6669

6770
# Attributes to be put in metadata
@@ -87,21 +90,31 @@ def __preprocess(
8790
GenerationDetectionRequest,
8891
ErrorResponse,
8992
]:
90-
"""Granite guardian specific parameter updates for risk name and risk definition"""
93+
"""Granite guardian specific parameter updates for risks/criteria"""
9194
# Validation that one of the 'defined' risks is requested will be
9295
# done through the chat template on each request. Errors will
9396
# be propagated for chat completion separately
9497
guardian_config = {}
9598
if not request.detector_params:
9699
return request
97100

101+
# Guardian 3.2 and earlier
98102
if risk_name := request.detector_params.pop("risk_name", None):
99103
guardian_config["risk_name"] = risk_name
100104
if risk_definition := request.detector_params.pop("risk_definition", None):
101105
guardian_config["risk_definition"] = risk_definition
106+
# Guardian 3.3+
107+
if criteria_id := request.detector_params.pop("criteria_id", None):
108+
guardian_config["criteria_id"] = criteria_id
109+
if custom_criteria := request.detector_params.pop("custom_criteria", None):
110+
guardian_config["custom_criteria"] = custom_criteria
111+
if custom_scoring_schema := request.detector_params.pop(
112+
"custom_scoring_schema", None
113+
):
114+
guardian_config["custom_scoring_schema"] = custom_scoring_schema
102115
if guardian_config:
103116
logger.debug("guardian_config {} provided for request", guardian_config)
104-
# Move the risk name and/or risk definition to chat_template_kwargs
117+
# Move the parameters to chat_template_kwargs
105118
# to be propagated to tokenizer.apply_chat_template during
106119
# chat completion
107120
if "chat_template_kwargs" in request.detector_params:
@@ -134,15 +147,33 @@ def _make_tools_request(
134147

135148
if (
136149
"risk_name" not in request.detector_params
137-
or request.detector_params["risk_name"] not in self.TOOLS_RISKS
150+
and "criteria_id" not in request.detector_params
151+
):
152+
return ErrorResponse(
153+
message="tools analysis is not supported without a given risk/criteria",
154+
type="BadRequestError",
155+
code=HTTPStatus.BAD_REQUEST.value,
156+
)
157+
# Granite Guardian 3.2 and earlier
158+
if (
159+
"risk_name" in request.detector_params
160+
and request.detector_params["risk_name"] not in self.TOOLS_RISKS
138161
):
139-
# Provide error here, since otherwise follow-on tools message
140-
# and assistant message flattening will not be applicable
141162
return ErrorResponse(
142163
message="tools analysis is not supported with given risk",
143164
type="BadRequestError",
144165
code=HTTPStatus.BAD_REQUEST.value,
145166
)
167+
# Granite Guardian 3.3+
168+
elif (
169+
"criteria_id" in request.detector_params
170+
and request.detector_params["criteria_id"] not in self.TOOLS_RISKS
171+
):
172+
return ErrorResponse(
173+
message="tools analysis is not supported with given criteria",
174+
type="BadRequestError",
175+
code=HTTPStatus.BAD_REQUEST.value,
176+
)
146177

147178
# (1) 'Flatten' the assistant message, extracting the functions in the tool_calls
148179
# portion of the message
@@ -242,7 +273,7 @@ def _make_tools_request(
242273
def _request_to_chat_completion_request(
243274
self, request: ContextAnalysisRequest, model_name: str
244275
) -> Union[ChatCompletionRequest, ErrorResponse]:
245-
NO_RISK_NAME_MESSAGE = "No risk_name for context analysis"
276+
NO_RISK_NAME_MESSAGE = "No risk_name or criteria_id for context analysis"
246277

247278
risk_name = None
248279
if (
@@ -259,8 +290,10 @@ def _request_to_chat_completion_request(
259290
"guardian_config"
260291
]:
261292
if isinstance(guardian_config, dict):
262-
risk_name = guardian_config.get("risk_name")
263-
# Leaving off risk name can lead to model/template errors
293+
risk_name = guardian_config.get("risk_name") or guardian_config.get(
294+
"criteria_id"
295+
)
296+
# Leaving off risk_name and criteria_id can lead to model/template errors
264297
if not risk_name:
265298
return ErrorResponse(
266299
message=NO_RISK_NAME_MESSAGE,
@@ -292,9 +325,9 @@ def _request_to_chat_completion_request(
292325
{"role": "user", "content": content},
293326
]
294327
else:
295-
# Return error if risk names are not expected ones
328+
# Return error if risk names or criteria are not expected ones
296329
return ErrorResponse(
297-
message="risk_name {} is not compatible with context analysis".format(
330+
message="risk_name or criteria_id {} is not compatible with context analysis".format(
298331
risk_name
299332
),
300333
type="BadRequestError",
@@ -448,10 +481,18 @@ async def generation_analyze(
448481

449482
# If risk_name is not specifically provided for this endpoint, we will add a
450483
# risk_name, since the user has already decided to use this particular endpoint
484+
# Granite Guardian 3.2 and earlier
451485
if "risk_name" not in request.detector_params:
452486
request.detector_params[
453487
"risk_name"
454488
] = self.DEFAULT_GENERATION_DETECTION_RISK
489+
# Granite Guardian 3.3+
490+
# Generally the additional/repeated risk is not problematic
491+
# This avoids having to verify Guardian version at this step
492+
if "criteria_id" not in request.detector_params:
493+
request.detector_params[
494+
"criteria_id"
495+
] = self.DEFAULT_GENERATION_DETECTION_RISK
455496

456497
# Task template not applied for generation analysis at this time
457498
# Make model-dependent adjustments for the request

0 commit comments

Comments
 (0)