Skip to content

Commit 51945da

Browse files
authored
Merge pull request #31 from gkumbhat/add_unit_test
Fix issue with granite guardian not working and add unit tests
2 parents 134c816 + 5c3c21a commit 51945da

File tree

3 files changed

+134
-2
lines changed

3 files changed

+134
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vllm-detector-adapter"
3-
version = "0.0.1"
3+
version = "0.4.2"
44
authors = [
55
{ name="Gaurav Kumbhat", email="[email protected]" },
66
{ name="Evaline Ju", email="[email protected]" },

tests/test_detector_dispatcher.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Third Party
2+
import pytest
3+
4+
# Local
5+
from vllm_detector_adapter.detector_dispatcher import detector_dispatcher
6+
7+
### Global Declarations ########################################################
8+
9+
10+
@detector_dispatcher(types=["foo"])
11+
def add(data):
12+
return data + 1
13+
14+
15+
@detector_dispatcher(types=["bar"])
16+
def add(data):
17+
return str(data) + " 1"
18+
19+
20+
@detector_dispatcher(types=["foo", "bar"])
21+
def diff(data):
22+
return data - 1
23+
24+
25+
class Foo:
26+
def __init__(self, x):
27+
self.x = x
28+
29+
@detector_dispatcher(types=["foo"])
30+
def add(self, data):
31+
return data + 2
32+
33+
@detector_dispatcher(types=["bar"])
34+
def add(self, data):
35+
return str(data) + " 2"
36+
37+
38+
### Tests #####################################################################
39+
40+
41+
def test_detector_dispatching_to_module_functions():
42+
"Test that correct function gets called when each are of different type"
43+
assert add(1, fn_type="foo") == 2
44+
assert add(1, fn_type="bar") == "1 1"
45+
46+
47+
def test_detector_dispatching_to_module_functions_2_types():
48+
"Test that correct function gets called when function is of 2 types"
49+
assert diff(1, fn_type="foo") == 0
50+
assert diff(1, fn_type="bar") == 0
51+
52+
53+
def test_detector_dispatching_for_methods_works():
54+
"Test that correct instance method gets called when each are of different type"
55+
instance = Foo(1)
56+
assert instance.add(1, fn_type="foo") == 3
57+
assert instance.add(1, fn_type="bar") == "1 2"
58+
59+
60+
def test_detector_dispatching_same_name_method_across_classes():
61+
"Test that correct instance method gets called when each are of different type"
62+
63+
# Create duplicate class with same method as Foo
64+
class Foo2:
65+
def __init__(self, x):
66+
self.x = x
67+
68+
@detector_dispatcher(types=["foo"])
69+
def add(self, data):
70+
return data + 20
71+
72+
instance_1 = Foo(1)
73+
instance_2 = Foo2(1)
74+
75+
assert instance_1.add(1, fn_type="foo") == 3
76+
assert instance_2.add(1, fn_type="foo") == 21
77+
78+
79+
def test_detector_dispatching_same_name_method_for_child_classes():
80+
"""Test that correct instance method gets recognized when using a subclass but
81+
decorator is applied in base class function.
82+
"""
83+
84+
class Foo3(Foo):
85+
@detector_dispatcher(types=["foo"])
86+
def add(self, *args, **kwargs):
87+
return super().add(*args, **kwargs, fn_type="foo")
88+
89+
@detector_dispatcher(types=["bar"])
90+
def add(self, data):
91+
return data + 30
92+
93+
instance = Foo3(1)
94+
95+
assert instance.add(1, fn_type="foo") == 3
96+
97+
98+
### Error Tests #############################################################
99+
100+
101+
def test_decorator_erroring_with_no_type_available_fn():
102+
"Test that an error is raised when function to given type is not available"
103+
with pytest.raises(ValueError):
104+
add(1, fn_type="baz")
105+
106+
107+
def test_decorator_error_with_no_fn_type_provided():
108+
"Test that an error is raised when no fn_type is provided when calling the function"
109+
with pytest.raises(ValueError):
110+
add(1)
111+
112+
113+
def test_decorator_erroring_with_duplicate_type_assignment():
114+
"""Test that an error is raised when same type is assigned to multiple functions
115+
of same name in same module
116+
"""
117+
with pytest.raises(ValueError):
118+
119+
@detector_dispatcher()
120+
def add(data, fn_type="foo"):
121+
pass

vllm_detector_adapter/generative_detectors/granite_guardian.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,21 @@ def _request_to_chat_completion_request(
144144

145145
##### General request / response processing functions ##################
146146

147+
@detector_dispatcher(types=[DetectorType.TEXT_CONTENT])
148+
def preprocess_request(self, *args, **kwargs):
149+
# FIXME: This function declaration is temporary and should be removed once we fix following
150+
# issue with decorator:
151+
# ISSUE: Because of inheritance, the base class function with same name gets overriden by the function
152+
# declared below for preprocessing TEXT_CHAT type detectors. This fails the validation inside
153+
# the detector_dispatcher decorator.
154+
return super().preprocess_request(
155+
*args, **kwargs, fn_type=DetectorType.TEXT_CONTENT
156+
)
157+
147158
# Used detector_dispatcher decorator to allow for the same function to be called
148159
# for different types of detectors with different request types etc.
149160
@detector_dispatcher(types=[DetectorType.TEXT_CHAT])
150-
def preprocess_request(
161+
def preprocess_request( # noqa: F811
151162
self, request: ChatDetectionRequest
152163
) -> Union[ChatDetectionRequest, ErrorResponse]:
153164
"""Granite guardian chat request preprocess is just detector parameter updates"""

0 commit comments

Comments
 (0)