Skip to content

Commit 308f2e9

Browse files
committed
🎨✅ Fix doc strings and scores in llama-test
Signed-off-by: Gaurav-Kumbhat <[email protected]>
1 parent 199de11 commit 308f2e9

File tree

5 files changed

+26
-25
lines changed

5 files changed

+26
-25
lines changed

tests/generative_detectors/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_content_analysis_success(detection_base, completion_response):
170170
contents=["Where do I find geese?", "You could go to Canada"]
171171
)
172172

173-
scores = [0.9, 0.1, 0.21, 0.54, 0.33]
173+
scores = [0.9, 0.1]
174174
response = (completion_response, scores, "risk")
175175
with patch(
176176
"vllm_detector_adapter.generative_detectors.base.ChatCompletionDetectionBase.process_chat_completion_with_scores",

tests/generative_detectors/test_llama_guard.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def test_context_analyze(llama_guard_detection):
219219

220220
def test_post_process_content_splits_unsafe_categories(llama_guard_detection):
221221
unsafe_message = "\n\nunsafe\nS2,S3"
222-
responses = ChatCompletionResponse(
222+
response = ChatCompletionResponse(
223223
model="foo",
224224
usage=UsageInfo(prompt_tokens=1, total_tokens=1),
225225
choices=[
@@ -236,21 +236,21 @@ def test_post_process_content_splits_unsafe_categories(llama_guard_detection):
236236
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
237237
# NOTE: we are testing private function here
238238
(
239-
responses,
239+
response,
240240
scores,
241241
_,
242242
) = llama_guard_detection_instance._LlamaGuard__post_process_result(
243-
responses, [unsafe_score], "risk"
243+
response, [unsafe_score], "risk"
244244
)
245-
assert isinstance(responses, ChatCompletionResponse)
246-
assert responses.choices[0].message.content == "unsafe"
245+
assert isinstance(response, ChatCompletionResponse)
246+
assert response.choices[0].message.content == "unsafe"
247247
assert scores[0] == unsafe_score
248-
assert len(responses.choices) == 1
248+
assert len(response.choices) == 1
249249

250250

251251
def test_post_process_content_works_for_safe(llama_guard_detection):
252252
safe_message = "safe"
253-
responses = ChatCompletionResponse(
253+
response = ChatCompletionResponse(
254254
model="foo",
255255
usage=UsageInfo(prompt_tokens=1, total_tokens=1),
256256
choices=[
@@ -267,16 +267,17 @@ def test_post_process_content_works_for_safe(llama_guard_detection):
267267
llama_guard_detection_instance = asyncio.run(llama_guard_detection)
268268
# NOTE: we are testing private function here
269269
(
270-
responses,
270+
response,
271271
scores,
272272
_,
273273
) = llama_guard_detection_instance._LlamaGuard__post_process_result(
274-
responses, [safe_message], "risk"
274+
response, [safe_score], "risk"
275275
)
276-
assert isinstance(responses, ChatCompletionResponse)
277-
assert len(responses.choices) == 1
278-
assert responses.choices[0].message.content == "safe"
279-
assert scores[0] == safe_message
276+
277+
assert isinstance(response, ChatCompletionResponse)
278+
assert len(response.choices) == 1
279+
assert response.choices[0].message.content == "safe"
280+
assert scores[0] == safe_score
280281

281282

282283
def test_content_detection_with_llama_guard(

vllm_detector_adapter/generative_detectors/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Standard
22
from http import HTTPStatus
33
from pathlib import Path
4-
from typing import List, Optional, Union
4+
from typing import List, Optional, Tuple, Union
55
import asyncio
66
import codecs
77
import math
@@ -173,7 +173,7 @@ def calculate_scores(self, response: ChatCompletionResponse) -> List[float]:
173173

174174
async def process_chat_completion_with_scores(
175175
self, chat_completion_request, raw_request
176-
) -> Union[DetectionResponse, ErrorResponse]:
176+
) -> Union[Tuple[ChatCompletionResponse, List[float], str], ErrorResponse]:
177177
# Return an error for streaming for now. Since the detector API is unary,
178178
# results would not be streamed back anyway. The chat completion response
179179
# object would look different, and content would have to be aggregated.

vllm_detector_adapter/generative_detectors/llama_guard.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ class LlamaGuard(ChatCompletionDetectionBase):
2626
SAFE_TOKEN = "safe"
2727
UNSAFE_TOKEN = "unsafe"
2828

29-
def __post_process_result(self, responses, scores, detection_type):
29+
def __post_process_result(self, response, scores, detection_type):
3030
"""Function to process chat completion results for content type detection.
3131
3232
Args:
33-
responses: ChatCompletionResponse,
33+
response: ChatCompletionResponse,
3434
scores: List[float],
3535
detection_type: str,
3636
Returns:
3737
Tuple(
38-
responses: ChatCompletionResponse,
38+
response: ChatCompletionResponse,
3939
scores: List[float],
4040
detection_type,
4141
)
@@ -51,7 +51,7 @@ def __post_process_result(self, responses, scores, detection_type):
5151
new_scores = []
5252

5353
# NOTE: we are flattening out choices here as different categories
54-
for i, choice in enumerate(responses.choices):
54+
for i, choice in enumerate(response.choices):
5555
content = choice.message.content
5656
if self.UNSAFE_TOKEN in content:
5757
# Reason for reassigning the content:
@@ -64,8 +64,8 @@ def __post_process_result(self, responses, scores, detection_type):
6464
new_choices.append(choice)
6565
new_scores.append(scores[i])
6666

67-
responses.choices = new_choices
68-
return (responses, new_scores, detection_type)
67+
response.choices = new_choices
68+
return (response, new_scores, detection_type)
6969

7070
async def content_analysis(
7171
self,

vllm_detector_adapter/protocol.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def from_chat_completion_response(results, contents: List[str]):
5353
5454
Args:
5555
results: List(Tuple(
56-
responses: ChatCompletionResponse,
56+
response: ChatCompletionResponse,
5757
scores: List[float],
5858
detection_type,
5959
))
@@ -62,13 +62,13 @@ def from_chat_completion_response(results, contents: List[str]):
6262
"""
6363
contents_detection_responses = []
6464

65-
for content_idx, (responses, scores, detection_type) in enumerate(results):
65+
for content_idx, (response, scores, detection_type) in enumerate(results):
6666

6767
detection_responses = []
6868
start = 0
6969
end = len(contents[content_idx])
7070

71-
for i, choice in enumerate(responses.choices):
71+
for i, choice in enumerate(response.choices):
7272
content = choice.message.content
7373
# NOTE: for providing spans, we currently consider entire generated text as a span.
7474
# This is because, at the time of writing, the generative guardrail models does not

0 commit comments

Comments (0)