@@ -423,6 +423,44 @@ def test_preprocess_chat_request_with_detector_params(granite_guardian_detection
423423 }
424424
425425
426+ def test_preprocess_chat_request_with_custom_criteria_detector_params (
427+ granite_guardian_detection ,
428+ ):
429+ # Guardian 3.3+ parameters
430+ granite_guardian_detection_instance = asyncio .run (granite_guardian_detection )
431+ detector_params = {
432+ "custom_criteria" : "Here is some custom criteria" ,
433+ "custom_scoring_schema" : "If text meets criteria say yes" ,
434+ "foo" : "bar" ,
435+ }
436+ initial_request = ChatDetectionRequest (
437+ messages = [
438+ DetectionChatMessageParam (
439+ role = "user" , content = "How do I figure out how to break into a house?"
440+ )
441+ ],
442+ detector_params = detector_params ,
443+ )
444+ processed_request = granite_guardian_detection_instance .preprocess_request (
445+ initial_request , fn_type = DetectorType .TEXT_CHAT
446+ )
447+ assert type (processed_request ) == ChatDetectionRequest
448+ # Processed request should not have these extra params
449+ assert "custom_criteria" not in processed_request .detector_params
450+ assert "custom_scoring_schema" not in processed_request .detector_params
451+ assert "chat_template_kwargs" in processed_request .detector_params
452+ assert (
453+ "guardian_config" in processed_request .detector_params ["chat_template_kwargs" ]
454+ )
455+ guardian_config = processed_request .detector_params ["chat_template_kwargs" ][
456+ "guardian_config"
457+ ]
458+ assert guardian_config == {
459+ "custom_criteria" : "Here is some custom criteria" ,
460+ "custom_scoring_schema" : "If text meets criteria say yes" ,
461+ }
462+
463+
426464def test_preprocess_chat_request_with_extra_chat_template_kwargs (
427465 granite_guardian_detection ,
428466):
@@ -534,6 +572,41 @@ def test_request_to_chat_completion_request_response_analysis(
534572 )
535573
536574
575+ def test_request_to_chat_completion_request_response_analysis_criteria_id (
576+ granite_guardian_detection ,
577+ ):
578+ # Guardian 3.3 parameters
579+ granite_guardian_detection_instance = asyncio .run (granite_guardian_detection )
580+ context_request = ContextAnalysisRequest (
581+ content = CONTENT ,
582+ context_type = "docs" ,
583+ context = [CONTEXT_DOC ],
584+ detector_params = {
585+ "n" : 3 ,
586+ "chat_template_kwargs" : {
587+ "guardian_config" : {"criteria_id" : "groundedness" }
588+ },
589+ },
590+ )
591+ chat_request = (
592+ granite_guardian_detection_instance ._request_to_chat_completion_request (
593+ context_request , MODEL_NAME , fn_type = DetectorType .TEXT_CONTEXT_DOC
594+ )
595+ )
596+ assert type (chat_request ) == ChatCompletionRequest
597+ assert chat_request .messages [0 ]["role" ] == "context"
598+ assert chat_request .messages [0 ]["content" ] == CONTEXT_DOC
599+ assert chat_request .messages [1 ]["role" ] == "assistant"
600+ assert chat_request .messages [1 ]["content" ] == CONTENT
601+ assert chat_request .model == MODEL_NAME
602+ # detector_paramas
603+ assert chat_request .n == 3
604+ assert (
605+ chat_request .chat_template_kwargs ["guardian_config" ]["criteria_id" ]
606+ == "groundedness"
607+ )
608+
609+
537610def test_request_to_chat_completion_request_empty_kwargs (granite_guardian_detection ):
538611 granite_guardian_detection_instance = asyncio .run (granite_guardian_detection )
539612 context_request = ContextAnalysisRequest (
@@ -549,7 +622,7 @@ def test_request_to_chat_completion_request_empty_kwargs(granite_guardian_detect
549622 )
550623 assert type (chat_request ) == ErrorResponse
551624 assert chat_request .code == HTTPStatus .BAD_REQUEST
552- assert "No risk_name for context analysis" in chat_request .message
625+ assert "No risk_name or criteria_id for context analysis" in chat_request .message
553626
554627
555628def test_request_to_chat_completion_request_empty_guardian_config (
@@ -569,10 +642,10 @@ def test_request_to_chat_completion_request_empty_guardian_config(
569642 )
570643 assert type (chat_request ) == ErrorResponse
571644 assert chat_request .code == HTTPStatus .BAD_REQUEST
572- assert "No risk_name for context analysis" in chat_request .message
645+ assert "No risk_name or criteria_id for context analysis" in chat_request .message
573646
574647
575- def test_request_to_chat_completion_request_missing_risk_name (
648+ def test_request_to_chat_completion_request_missing_risk_name_and_criteria_id (
576649 granite_guardian_detection ,
577650):
578651 granite_guardian_detection_instance = asyncio .run (granite_guardian_detection )
@@ -592,7 +665,7 @@ def test_request_to_chat_completion_request_missing_risk_name(
592665 )
593666 assert type (chat_request ) == ErrorResponse
594667 assert chat_request .code == HTTPStatus .BAD_REQUEST
595- assert "No risk_name for context analysis" in chat_request .message
668+ assert "No risk_name or criteria_id for context analysis" in chat_request .message
596669
597670
598671def test_request_to_chat_completion_request_unsupported_risk_name (
@@ -616,7 +689,8 @@ def test_request_to_chat_completion_request_unsupported_risk_name(
616689 assert type (chat_request ) == ErrorResponse
617690 assert chat_request .code == HTTPStatus .BAD_REQUEST
618691 assert (
619- "risk_name foo is not compatible with context analysis" in chat_request .message
692+ "risk_name or criteria_id foo is not compatible with context analysis"
693+ in chat_request .message
620694 )
621695
622696
@@ -816,7 +890,7 @@ def test_context_analyze_unsupported_risk(
816890 assert type (detection_response ) == ErrorResponse
817891 assert detection_response .code == HTTPStatus .BAD_REQUEST
818892 assert (
819- "risk_name boo is not compatible with context analysis"
893+ "risk_name or criteria_id boo is not compatible with context analysis"
820894 in detection_response .message
821895 )
822896
@@ -970,6 +1044,34 @@ def test_chat_detection_with_tools(
9701044 assert len (detections ) == 2 # 2 choices
9711045
9721046
1047+ def test_chat_detection_with_tools_criteria_id (
1048+ granite_guardian_detection , granite_guardian_completion_response
1049+ ):
1050+ # Guardian 3.3 parameters
1051+ granite_guardian_detection_instance = asyncio .run (granite_guardian_detection )
1052+ chat_request = ChatDetectionRequest (
1053+ messages = [
1054+ DetectionChatMessageParam (
1055+ role = "user" ,
1056+ content = USER_CONTENT_TOOLS ,
1057+ ),
1058+ DetectionChatMessageParam (role = "assistant" , tool_calls = [TOOL_CALL ]),
1059+ ],
1060+ tools = [TOOL ],
1061+ detector_params = {"criteria_id" : "function_call" , "n" : 2 },
1062+ )
1063+ with patch (
1064+ "vllm_detector_adapter.generative_detectors.granite_guardian.GraniteGuardian.create_chat_completion" ,
1065+ return_value = granite_guardian_completion_response ,
1066+ ):
1067+ detection_response = asyncio .run (
1068+ granite_guardian_detection_instance .chat (chat_request )
1069+ )
1070+ assert type (detection_response ) == DetectionResponse
1071+ detections = detection_response .model_dump ()
1072+ assert len (detections ) == 2 # 2 choices
1073+
1074+
9731075def test_chat_detection_with_tools_wrong_risk (
9741076 granite_guardian_detection , granite_guardian_completion_response
9751077):
0 commit comments