
Commit 54050ae

Merge pull request #21 from gkumbhat/add_content_detector
Add content detector
2 parents 98295e4 + 308f2e9 commit 54050ae

8 files changed: +652 -12 lines changed

tests/generative_detectors/test_base.py

Lines changed: 99 additions & 0 deletions
@@ -1,16 +1,31 @@
 # Standard
 from dataclasses import dataclass
 from typing import Optional
+from unittest.mock import patch
 import asyncio

 # Third Party
 from vllm.config import MultiModalConfig
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionLogProb,
+    ChatCompletionLogProbs,
+    ChatCompletionLogProbsContent,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatMessage,
+    UsageInfo,
+)
 from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 import jinja2
+import pytest
 import pytest_asyncio

 # Local
 from vllm_detector_adapter.generative_detectors.base import ChatCompletionDetectionBase
+from vllm_detector_adapter.protocol import (
+    ContentsDetectionRequest,
+    ContentsDetectionResponse,
+)

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
@@ -82,6 +97,55 @@ async def detection_base():
     return _async_serving_detection_completion_init()


+@pytest.fixture(scope="module")
+def completion_response():
+    log_probs_content_no = ChatCompletionLogProbsContent(
+        token="no",
+        logprob=-0.0013,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token="no", logprob=-0.053),
+            ChatCompletionLogProb(token="0", logprob=-6.61),
+            ChatCompletionLogProb(token="1", logprob=-16.90),
+            ChatCompletionLogProb(token="2", logprob=-17.39),
+            ChatCompletionLogProb(token="3", logprob=-17.61),
+        ],
+    )
+    log_probs_content_yes = ChatCompletionLogProbsContent(
+        token="yes",
+        logprob=-0.0013,
+        # 5 logprobs requested for scoring, skipping bytes for conciseness
+        top_logprobs=[
+            ChatCompletionLogProb(token="yes", logprob=-0.0013),
+            ChatCompletionLogProb(token="0", logprob=-6.61),
+            ChatCompletionLogProb(token="1", logprob=-16.90),
+            ChatCompletionLogProb(token="2", logprob=-17.39),
+            ChatCompletionLogProb(token="3", logprob=-17.61),
+        ],
+    )
+    choice_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content="no",
+        ),
+        logprobs=ChatCompletionLogProbs(content=[log_probs_content_no]),
+    )
+    choice_1 = ChatCompletionResponseChoice(
+        index=1,
+        message=ChatMessage(
+            role="assistant",
+            content="yes",
+        ),
+        logprobs=ChatCompletionLogProbs(content=[log_probs_content_yes]),
+    )
+    yield ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_0, choice_1],
+        usage=UsageInfo(prompt_tokens=200, total_tokens=206, completion_tokens=6),
+    )
+
+
 ### Tests #####################################################################


@@ -97,3 +161,38 @@ def test_async_serving_detection_completion_init(detection_base):
     output_template = detection_completion.output_template
     assert type(output_template) == jinja2.environment.Template
     assert output_template.render(({"text": "moose"})) == "bye moose"
+
+
+def test_content_analysis_success(detection_base, completion_response):
+    base_instance = asyncio.run(detection_base)
+
+    content_request = ContentsDetectionRequest(
+        contents=["Where do I find geese?", "You could go to Canada"]
+    )
+
+    scores = [0.9, 0.1]
+    response = (completion_response, scores, "risk")
+    with patch(
+        "vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",
+        return_value=response,
+    ):
+        result = asyncio.run(base_instance.content_analysis(content_request))
+        assert isinstance(result, ContentsDetectionResponse)
+        detections = result.model_dump()
+        assert len(detections) == 2
+        # For first content
+        assert detections[0][0]["detection"] == "no"
+        assert detections[0][0]["score"] == 0.9
+        assert detections[0][0]["start"] == 0
+        assert detections[0][0]["end"] == len(content_request.contents[0])
+        # 2nd choice as 2nd label
+        assert detections[0][1]["detection"] == "yes"
+        assert detections[0][1]["score"] == 0.1
+        assert detections[0][1]["start"] == 0
+        assert detections[0][1]["end"] == len(content_request.contents[0])
+        # For 2nd content, we are only testing 1st detection for simplicity
+        # Note: detection is same, because of how mock is working.
+        assert detections[1][0]["detection"] == "no"
+        assert detections[1][0]["score"] == 0.9
+        assert detections[1][0]["start"] == 0
+        assert detections[1][0]["end"] == len(content_request.contents[1])
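
As a reading aid for the new protocol objects exercised above: the request carries a list of independent text contents, and the response is one list of detection objects per content, each with character offsets, a label, a detection type, and a score. The sketch below is not part of the diff; it only restates those shapes with the field names used in this commit's tests and assumes the vllm-detector-adapter package is importable.

# Illustrative sketch (not part of this commit): request/response shapes for
# the new content detection path, using only names introduced in these tests.
from vllm_detector_adapter.protocol import (
    ContentsDetectionRequest,
    ContentsDetectionResponse,
    ContentsDetectionResponseObject,
)

# One request can carry several independent contents.
request = ContentsDetectionRequest(
    contents=["Where do I find geese?", "You could go to Canada"]
)

# One inner list of detection objects per content, spanning the whole text.
response = ContentsDetectionResponse(
    root=[
        [
            ContentsDetectionResponseObject(
                start=0,
                end=len(text),
                text=text,
                detection="no",
                detection_type="risk",
                score=0.9,
            )
        ]
        for text in request.contents
    ]
)
print(response.model_dump())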

tests/generative_detectors/test_llama_guard.py

Lines changed: 89 additions & 0 deletions
@@ -25,6 +25,8 @@
 from vllm_detector_adapter.generative_detectors.llama_guard import LlamaGuard
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
+    ContentsDetectionRequest,
+    ContentsDetectionResponse,
     ContextAnalysisRequest,
     DetectionChatMessageParam,
     DetectionResponse,
@@ -213,3 +215,90 @@ def test_context_analyze(llama_guard_detection):
     )
     assert type(response) == ErrorResponse
     assert response.code == HTTPStatus.NOT_IMPLEMENTED
+
+
+def test_post_process_content_splits_unsafe_categories(llama_guard_detection):
+    unsafe_message = "\n\nunsafe\nS2,S3"
+    response = ChatCompletionResponse(
+        model="foo",
+        usage=UsageInfo(prompt_tokens=1, total_tokens=1),
+        choices=[
+            ChatCompletionResponseChoice(
+                index=1,
+                message=ChatMessage(
+                    content=unsafe_message,
+                    role=" assistant",
+                ),
+            )
+        ],
+    )
+    unsafe_score = 0.99
+    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
+    # NOTE: we are testing private function here
+    (
+        response,
+        scores,
+        _,
+    ) = llama_guard_detection_instance._LlamaGuard__post_process_result(
+        response, [unsafe_score], "risk"
+    )
+    assert isinstance(response, ChatCompletionResponse)
+    assert response.choices[0].message.content == "unsafe"
+    assert scores[0] == unsafe_score
+    assert len(response.choices) == 1
+
+
+def test_post_process_content_works_for_safe(llama_guard_detection):
+    safe_message = "safe"
+    response = ChatCompletionResponse(
+        model="foo",
+        usage=UsageInfo(prompt_tokens=1, total_tokens=1),
+        choices=[
+            ChatCompletionResponseChoice(
+                index=1,
+                message=ChatMessage(
+                    content=safe_message,
+                    role=" assistant",
+                ),
+            )
+        ],
+    )
+    safe_score = 0.99
+    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
+    # NOTE: we are testing private function here
+    (
+        response,
+        scores,
+        _,
+    ) = llama_guard_detection_instance._LlamaGuard__post_process_result(
+        response, [safe_score], "risk"
+    )
+
+    assert isinstance(response, ChatCompletionResponse)
+    assert len(response.choices) == 1
+    assert response.choices[0].message.content == "safe"
+    assert scores[0] == safe_score
+
+
+def test_content_detection_with_llama_guard(
+    llama_guard_detection, llama_guard_completion_response
+):
+    llama_guard_detection_instance = asyncio.run(llama_guard_detection)
+    content_request = ContentsDetectionRequest(
+        contents=["Where do I find geese?", "You could go to Canada"]
+    )
+    with patch(
+        "vllm_detector_adapter.generative_detectors.llama_guard.LlamaGuard.create_chat_completion",
+        return_value=llama_guard_completion_response,
+    ):
+        detection_response = asyncio.run(
+            llama_guard_detection_instance.content_analysis(content_request)
+        )
+        assert type(detection_response) == ContentsDetectionResponse
+        detections = detection_response.model_dump()
+        assert len(detections) == 2  # 2 contents in the request
+        assert len(detections[0]) == 2  # 2 choices
+        detection_0 = detections[0][0]  # for 1st text in request
+        assert detection_0["detection"] == "safe"
+        assert detection_0["detection_type"] == "risk"
+        assert pytest.approx(detection_0["score"]) == 0.001346767
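
The two post-processing tests above pin down the observable behaviour of the private __post_process_result on LlamaGuard: a plain "safe" generation passes through unchanged, while an "unsafe" generation followed by category codes such as "S2,S3" is reduced to the bare "unsafe" label. The snippet below is a stand-alone illustration of that splitting behaviour only, not the adapter's actual implementation; the helper name is hypothetical.

# Hypothetical helper, for illustration only: mirrors the label/category split
# that the tests above expect from Llama Guard style output.
def split_llama_guard_output(message: str) -> tuple[str, list[str]]:
    lines = [line for line in message.strip().split("\n") if line]
    label = lines[0]  # "safe" or "unsafe"
    categories = lines[1].split(",") if len(lines) > 1 else []
    return label, categories

print(split_llama_guard_output("\n\nunsafe\nS2,S3"))  # ('unsafe', ['S2', 'S3'])
print(split_llama_guard_output("safe"))               # ('safe', [])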

tests/test_protocol.py

Lines changed: 150 additions & 0 deletions
@@ -14,6 +14,8 @@
 # Local
 from vllm_detector_adapter.protocol import (
     ChatDetectionRequest,
+    ContentsDetectionResponse,
+    ContentsDetectionResponseObject,
     DetectionChatMessageParam,
     DetectionResponse,
 )
@@ -129,3 +131,151 @@ def test_response_from_completion_response_missing_content():
         in detection_response.message
     )
     assert detection_response.code == HTTPStatus.BAD_REQUEST.value
+
+
+def test_response_from_single_content_detection_response():
+    choice = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content=" moose",
+        ),
+    )
+    chat_response = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+    contents = ["sample sentence"]
+    scores = [0.9]
+    detection_type = "risk"
+
+    expected_response = ContentsDetectionResponse(
+        root=[
+            [
+                ContentsDetectionResponseObject(
+                    start=0,
+                    end=len(contents[0]),
+                    score=scores[0],
+                    text=contents[0],
+                    detection="moose",
+                    detection_type=detection_type,
+                )
+            ]
+        ]
+    )
+    detection_response = ContentsDetectionResponse.from_chat_completion_response(
+        [(chat_response, scores, detection_type)], contents
+    )
+    assert isinstance(detection_response, ContentsDetectionResponse)
+    assert detection_response == expected_response
+
+
+def test_response_from_multi_contents_detection_response():
+    choice_content_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content=" moose",
+        ),
+    )
+    choice_content_1 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content=" goose",
+        ),
+    )
+    chat_response_0 = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_content_0],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+    chat_response_1 = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_content_1],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+
+    contents = ["sample sentence 1", "sample sentence 2"]
+    # scores for each content is a list of scores (for multi-label)
+    scores = [[0.9], [0.6]]
+    detection_type = "risk"
+
+    content_response_0 = [
+        ContentsDetectionResponseObject(
+            start=0,
+            end=len(contents[0]),
+            score=scores[0][0],
+            text=contents[0],
+            detection="moose",
+            detection_type=detection_type,
+        )
+    ]
+    content_response_1 = [
+        ContentsDetectionResponseObject(
+            start=0,
+            end=len(contents[1]),
+            score=scores[1][0],
+            text=contents[1],
+            detection="goose",
+            detection_type=detection_type,
+        )
+    ]
+    expected_response = ContentsDetectionResponse(
+        root=[content_response_0, content_response_1]
+    )
+    detection_response = ContentsDetectionResponse.from_chat_completion_response(
+        [
+            (chat_response_0, scores[0], detection_type),
+            (chat_response_1, scores[1], detection_type),
+        ],
+        contents,
+    )
+    assert isinstance(detection_response, ContentsDetectionResponse)
+    assert detection_response == expected_response
+
+
+def test_response_from_single_content_detection_missing_content():
+    choice_content_0 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+        ),
+    )
+    choice_content_1 = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(
+            role="assistant",
+            content=" goose",
+        ),
+    )
+    chat_response_0 = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_content_0],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+    chat_response_1 = ChatCompletionResponse(
+        model=MODEL_NAME,
+        choices=[choice_content_1],
+        usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
+    )
+
+    contents = ["sample sentence 1", "sample sentence 2"]
+    # scores for each content is a list of scores (for multi-label)
+    scores = [[0.9], [0.6]]
+    detection_type = "risk"
+
+    detection_response = ContentsDetectionResponse.from_chat_completion_response(
+        [
+            (chat_response_0, scores[0], detection_type),
+            (chat_response_1, scores[1], detection_type),
+        ],
+        contents,
+    )
+    assert type(detection_response) == ErrorResponse
+    assert (
+        "Choice 0 from chat completion does not have content"
+        in detection_response.message
+    )
+    assert detection_response.code == HTTPStatus.BAD_REQUEST.value
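
Taken together, these protocol tests document the conversion path: each content gets its own (chat_response, scores, detection_type) tuple, and ContentsDetectionResponse.from_chat_completion_response turns each choice message into a detection object with whole-span offsets (the expected detection "moose" for a message content of " moose" suggests leading whitespace on the generated label is stripped), returning an ErrorResponse when a choice has no content. Below is a minimal sketch of calling the converter directly; it assumes the vllm and vllm-detector-adapter packages are importable and reuses only names from the tests above.

# Sketch only: drive the converter tested above with a single content.
from vllm.entrypoints.openai.protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatMessage,
    UsageInfo,
)
from vllm_detector_adapter.protocol import ContentsDetectionResponse

contents = ["sample sentence"]
chat_response = ChatCompletionResponse(
    model="openai-community/gpt2",
    choices=[
        ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=" moose"),
        )
    ],
    usage=UsageInfo(prompt_tokens=136, total_tokens=140, completion_tokens=4),
)

# One (response, scores, detection_type) tuple per input content.
detection_response = ContentsDetectionResponse.from_chat_completion_response(
    [(chat_response, [0.9], "risk")], contents
)
print(detection_response.model_dump())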
