Skip to content

Commit 7678885

Browse files
gkumbhat and evaline-ju authored
🐛 Fix llama guard adapter breaking because of metadata for different choice responses (#43)
* 🐛 Fix llama guard adapter breaking because of metadata for different choice responses

  Signed-off-by: Gaurav-Kumbhat <[email protected]>

* Update tests/generative_detectors/test_llama_guard.py

  Co-authored-by: Evaline Ju <[email protected]>
  Signed-off-by: Gaurav Kumbhat <[email protected]>

---------

Signed-off-by: Gaurav-Kumbhat <[email protected]>
Signed-off-by: Gaurav Kumbhat <[email protected]>
Co-authored-by: Evaline Ju <[email protected]>
1 parent f6e5d5a commit 7678885

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

tests/generative_detectors/test_llama_guard.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,52 @@ def test_post_process_content_works_for_safe(llama_guard_detection):
231231
assert len(new_response.choices) == 1
232232
assert detection_type == "risk"
233233
# post_process_completion_results function returns array of metadata per choice
234-
assert metadata == []
234+
assert metadata == [{}]
235+
236+
237+
def test_post_process_content_splits_safe_and_unsafe_categories(llama_guard_detection):
238+
safe_message = "safe"
239+
unsafe_message = "\n\nunsafe\nS2,S3"
240+
response = ChatCompletionResponse(
241+
model="foo",
242+
usage=UsageInfo(prompt_tokens=1, total_tokens=1),
243+
choices=[
244+
ChatCompletionResponseChoice(
245+
index=0,
246+
message=ChatMessage(
247+
content=safe_message,
248+
role=" assistant",
249+
),
250+
),
251+
ChatCompletionResponseChoice(
252+
index=1,
253+
message=ChatMessage(
254+
content=unsafe_message,
255+
role=" assistant",
256+
),
257+
),
258+
],
259+
)
260+
261+
expected_metadata = [{}, {"categories": ["Non-Violent Crimes.", "Sex Crimes."]}]
262+
263+
safe_score = 0.6
264+
unsafe_score = 0.99
265+
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
266+
# NOTE: we are testing private function here
267+
(new_response, scores, detection_type, metadata,) = asyncio.run(
268+
llama_guard_detection_instance.post_process_completion_results(
269+
response, [safe_score, unsafe_score], "risk"
270+
)
271+
)
272+
273+
assert isinstance(new_response, ChatCompletionResponse)
274+
assert new_response.choices[1].message.content == "unsafe"
275+
assert scores[1] == unsafe_score
276+
assert len(new_response.choices) == 2
277+
assert detection_type == "risk"
278+
# post_process_completion_results function returns array of metadata per choice
279+
assert metadata == expected_metadata
235280

236281

237282
#### Content detection tests

vllm_detector_adapter/generative_detectors/llama_guard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ async def post_process_completion_results(self, response, scores, detection_type
102102
logger.warning(
103103
f"Category {category} not found in risk bank for model {self.__class__.__name__}"
104104
)
105-
metadata_per_choice.append(metadata)
106105
else:
107106
# "safe" case
108107
new_choices.append(choice)
109108
new_scores.append(scores[i])
109+
metadata_per_choice.append(metadata)
110110

111111
response.choices = new_choices
112112
return response, new_scores, detection_type, metadata_per_choice

0 commit comments

Comments (0)