Commit 9a66941

Merge pull request #79 from evaline-ju/llama-consolidate
♻️ Consolidate llama guard content detection
2 parents (57fc9b0 + d74b654), commit 9a66941

1 file changed (+3, -75 lines)

vllm_detector_adapter/generative_detectors/llama_guard.py

Lines changed: 3 additions & 75 deletions
@@ -1,7 +1,6 @@
 # Standard
 from http import HTTPStatus
-from typing import Optional
-import asyncio
+from typing import Optional, Union

 # Third Party
 from fastapi import Request
@@ -13,9 +12,7 @@
 from vllm_detector_adapter.protocol import (
     ContentsDetectionRequest,
     ContentsDetectionResponse,
-    ContentsDetectionResponseObject,
 )
-from vllm_detector_adapter.utils import DetectorType

 logger = init_logger(__name__)

@@ -122,19 +119,9 @@ async def content_analysis(
         self,
         request: ContentsDetectionRequest,
         raw_request: Optional[Request] = None,
-    ):
+    ) -> Union[ContentsDetectionResponse, ErrorResponse]:
         """Function used to call chat detection and provide a /text/contents response"""

-        # Apply task template if it exists
-        if self.task_template:
-            request = self.apply_task_template(
-                request, fn_type=DetectorType.TEXT_CONTENT
-            )
-            if isinstance(request, ErrorResponse):
-                # Propagate any request problems that will not allow
-                # task template to be applied
-                return request
-
         # Because conversation roles are expected to alternate between 'user' and 'assistant'
         # validate whether role_override was passed as a detector_param, which is invalid
         # since explicitly overriding the conversation roles will result in an error.
@@ -145,63 +132,4 @@ async def content_analysis(
                 code=HTTPStatus.BAD_REQUEST.value,
             )

-        # Since separate batch processing function doesn't exist at the time of writing,
-        # we are just going to collect all the text from content request and fire up
-        # separate requests and wait asynchronously.
-        # This mirrors how batching is handled in run_batch function in entrypoints/openai/
-        # in vLLM codebase.
-        completion_requests = self.preprocess_request(
-            request, fn_type=DetectorType.TEXT_CONTENT
-        )
-
-        # Send all the completion requests asynchronously.
-        tasks = [
-            asyncio.create_task(
-                self.process_chat_completion_with_scores(
-                    completion_request, raw_request
-                )
-            )
-            for completion_request in completion_requests
-        ]
-
-        # Gather all the results
-        # NOTE: The results are guaranteed to be in order of requests
-        results = await asyncio.gather(*tasks)
-
-        # If there is any error, return that otherwise, return the whole response
-        # properly formatted.
-        processed_result = []
-        for result_idx, result in enumerate(results):
-            # NOTE: we are only sending 1 of the error results
-            # and not every one (not cumulative)
-            if isinstance(result, ErrorResponse):
-                return result
-            else:
-                # Process results to split out safety categories into separate objects
-                (
-                    response,
-                    new_scores,
-                    detection_type,
-                    metadata_per_choice,
-                ) = await self.post_process_completion_results(*result)
-
-                new_result = (
-                    ContentsDetectionResponseObject.from_chat_completion_response(
-                        response,
-                        new_scores,
-                        detection_type,
-                        request.contents[result_idx],
-                        metadata_per_choice=metadata_per_choice,
-                    )
-                )
-
-                # Verify whether the new_result is the correct is an errorresponse, and if so, return the errorresponse
-                if isinstance(new_result, ErrorResponse):
-                    logger.debug(
-                        f"[content_analysis] ErrorResponse returned: {repr(new_result)}"
-                    )
-                    return new_result
-
-                processed_result.append(new_result)
-
-        return ContentsDetectionResponse(root=processed_result)
+        return await super().content_analysis(request, raw_request)
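For reference, the override removed above batched content analysis by hand: it built one chat completion request per content item, fired each off as an asyncio task, gathered the results in request order, and returned the first ErrorResponse it encountered. The new delegation to super().content_analysis() implies this logic now lives in a shared base class. The sketch below illustrates only that fan-out/gather pattern under stated assumptions; ErrorResponse is simplified here, and score_one_content / analyze_contents are hypothetical names, not part of vllm_detector_adapter.

# Minimal sketch (not the adapter's actual code): fan out one detection call per
# content item, await all of them, and keep results in request order.
import asyncio
from typing import List, Union


class ErrorResponse:
    # Simplified stand-in for the ErrorResponse model referenced in the diff above.
    def __init__(self, message: str, code: int = 400):
        self.message = message
        self.code = code


async def score_one_content(text: str) -> Union[dict, ErrorResponse]:
    # Hypothetical single-item detection call; the real adapter sends a chat
    # completion request and post-processes the returned scores.
    await asyncio.sleep(0)  # placeholder for the actual model call
    return {"text": text, "score": 0.0}


async def analyze_contents(contents: List[str]) -> Union[List[dict], ErrorResponse]:
    # asyncio.gather preserves the order of the tasks it is given, so the
    # results line up with the input contents.
    tasks = [asyncio.create_task(score_one_content(text)) for text in contents]
    results = await asyncio.gather(*tasks)

    processed = []
    for result in results:
        # Return the first error rather than a partial response, matching the
        # behavior of the override removed in this commit.
        if isinstance(result, ErrorResponse):
            return result
        processed.append(result)
    return processed


if __name__ == "__main__":
    print(asyncio.run(analyze_contents(["first text", "second text"])))

The order guarantee noted in the removed code holds because asyncio.gather returns results in the order of the awaitables passed to it, regardless of completion order.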
