Commit 72a316d

feat(core): Multiple endpoint types on benchmark and endpoint end
Signed-off-by: Tomasz Grzegorzek <[email protected]>
1 parent 1135472 commit 72a316d

File tree

3 files changed, +117 -13 lines changed


packages/nemo-evaluator/src/nemo_evaluator/api/api_dataclasses.py

Lines changed: 5 additions & 3 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.

 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import jinja2
 from pydantic import BaseModel, ConfigDict, Field
@@ -48,7 +48,7 @@ class ApiEndpoint(BaseModel):
     stream: Optional[bool] = Field(
         description="Whether responses should be streamed", default=None
     )
-    type: Optional[EndpointType] = Field(
+    type: Optional[Union[EndpointType, list[EndpointType]]] = Field(
         description="The type of the target", default=None
     )
     url: Optional[str] = Field(description="Url of the model", default=None)
@@ -108,7 +108,9 @@ class EvaluationConfig(BaseModel):
     params: Optional[ConfigParams] = Field(
         description="Parameters to be used for evaluation", default=None
     )
-    supported_endpoint_types: Optional[list[str]] = Field(
+    supported_endpoint_types: Optional[
+        Union[EndpointType, list[Union[EndpointType, list[EndpointType]]]]
+    ] = Field(
         description="Supported endpoint types like chat or completions", default=None
     )
     type: Optional[str] = Field(description="Type of the task", default=None)
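
The widened fields accept either a single EndpointType or a list of them, and supported_endpoint_types additionally accepts nested lists of alternative combinations. A minimal sketch of the new shapes, using only names from this diff (the url and model_id values are illustrative, mirroring the tests further down):

from nemo_evaluator.api.api_dataclasses import (
    ApiEndpoint,
    EndpointType,
    EvaluationConfig,
)

# A model endpoint may now advertise several types at once.
endpoint = ApiEndpoint(
    type=[EndpointType.CHAT, EndpointType.VLM],
    url="localhost",  # illustrative value
    model_id="my_model",  # illustrative value
)

# A benchmark may list alternative requirements; each inner list is one
# combination of endpoint types that the model must provide in full.
config = EvaluationConfig(
    supported_endpoint_types=[
        [EndpointType.COMPLETIONS, EndpointType.VLM],
        [EndpointType.CHAT, EndpointType.VLM],
    ]
)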

packages/nemo-evaluator/src/nemo_evaluator/core/input.py

Lines changed: 27 additions & 9 deletions
@@ -348,20 +348,38 @@ def get_evaluation(


 def check_type_compatibility(evaluation: Evaluation):
-    if (
-        evaluation.config.supported_endpoint_types is not None
-        and evaluation.target.api_endpoint.type
-        not in evaluation.config.supported_endpoint_types
-    ):
+    # Model endpoint types must be checked against benchmark required capabilities.
+    # All benchmark required capabilities must be present in model endpoint types.
+
+    # If the evaluation does not specify particular endpoint types,
+    # we treat it as 'any'.
+
+    # We have to be careful in terms of types. We might run into turning a stringable
+    # dataclass into a set.
+    if evaluation.config.supported_endpoint_types is not None:
+        if not isinstance(evaluation.target.api_endpoint.type, list):
+            evaluation.target.api_endpoint.type = [evaluation.target.api_endpoint.type]
+
+        if not isinstance(evaluation.config.supported_endpoint_types, list):
+            evaluation.config.supported_endpoint_types = [
+                evaluation.config.supported_endpoint_types
+            ]
+        model_types = set(evaluation.target.api_endpoint.type)
+        is_target_compatible = False
+        for benchmark_type_combination in evaluation.config.supported_endpoint_types:
+            if not isinstance(benchmark_type_combination, list):
+                benchmark_type_combination = [benchmark_type_combination]
+
+            if model_types.issuperset(set(benchmark_type_combination)):
+                is_target_compatible = True
+
         if evaluation.target.api_endpoint.type is None:
             raise MisconfigurationError(
                 "target.api_endpoint.type should be defined and match one of the endpoint "
                 f"types supported by the benchmark: '{evaluation.config.supported_endpoint_types}'",
             )
-        if (
-            evaluation.target.api_endpoint.type
-            not in evaluation.config.supported_endpoint_types
-        ):
+
+        if not is_target_compatible:
             raise MisconfigurationError(
                 f"The benchmark '{evaluation.config.type}' does not support the model type '{evaluation.target.api_endpoint.type}'. "
                 f"The benchmark supports '{evaluation.config.supported_endpoint_types}'."

packages/nemo-evaluator/tests/unit_tests/core/test_input.py

Lines changed: 85 additions & 1 deletion
@@ -14,7 +14,17 @@
 # limitations under the License.


-from nemo_evaluator.core.input import merge_dicts
+import pytest
+
+from nemo_evaluator.api.api_dataclasses import (
+    ApiEndpoint,
+    EndpointType,
+    Evaluation,
+    EvaluationConfig,
+    EvaluationTarget,
+)
+from nemo_evaluator.core.input import check_type_compatibility, merge_dicts
+from nemo_evaluator.core.utils import MisconfigurationError


 def test_distinct_keys():
@@ -64,3 +74,77 @@ def test_empty_dicts():
     d2 = {}
     assert merge_dicts(d1, d2) == {"b": 2}
     assert merge_dicts({}, {}) == {}
+
+
+@pytest.mark.parametrize(
+    "model_types,benchmark_types",
+    [
+        (EndpointType.CHAT, EndpointType.CHAT),
+        ([EndpointType.CHAT], [EndpointType.CHAT]),
+        (EndpointType.CHAT, [EndpointType.CHAT]),
+        ([EndpointType.CHAT], EndpointType.CHAT),
+        ("chat", "chat"),
+        ("chat", None),
+        ([EndpointType.CHAT, EndpointType.COMPLETIONS], [EndpointType.CHAT]),
+        ([EndpointType.CHAT, EndpointType.COMPLETIONS], EndpointType.CHAT),
+        (EndpointType.CHAT, [[EndpointType.CHAT], [EndpointType.COMPLETIONS]]),
+        (
+            [EndpointType.CHAT, EndpointType.COMPLETIONS],
+            [EndpointType.CHAT, EndpointType.COMPLETIONS],
+        ),
+        (
+            [EndpointType.CHAT, EndpointType.COMPLETIONS, EndpointType.VLM],
+            [EndpointType.CHAT, EndpointType.COMPLETIONS],
+        ),
+        (
+            [EndpointType.CHAT, EndpointType.VLM],
+            [
+                [EndpointType.COMPLETIONS, EndpointType.VLM],
+                [EndpointType.CHAT, EndpointType.VLM],
+            ],
+        ),
+    ],
+)
+def test_endpoint_type_single_compatible(model_types, benchmark_types):
+    evaluation_config = EvaluationConfig(supported_endpoint_types=benchmark_types)
+    target_config = EvaluationTarget(
+        api_endpoint=ApiEndpoint(type=model_types, url="localhost", model_id="my_model")
+    )
+    evaluation = Evaluation(
+        config=evaluation_config,
+        target=target_config,
+        command="",
+        pkg_name="",
+        framework_name="",
+    )
+    check_type_compatibility(evaluation)
+
+
+@pytest.mark.parametrize(
+    "model_types,benchmark_types",
+    [
+        (EndpointType.CHAT, EndpointType.COMPLETIONS),
+        ("chat", "vlm"),
+        ([EndpointType.CHAT], [[EndpointType.CHAT, EndpointType.VLM]]),
+        (
+            [EndpointType.CHAT, EndpointType.VLM],
+            [[EndpointType.COMPLETIONS, EndpointType.VLM]],
+        ),
+    ],
+)
+def test_endpoint_type_single_incompatible(model_types, benchmark_types):
+    evaluation_config = EvaluationConfig(supported_endpoint_types=benchmark_types)
+    target_config = EvaluationTarget(
+        api_endpoint=ApiEndpoint(type=model_types, url="localhost", model_id="my_model")
+    )
+    evaluation = Evaluation(
+        config=evaluation_config,
+        target=target_config,
+        command="",
+        pkg_name="",
+        framework_name="",
+    )
+    with pytest.raises(
+        MisconfigurationError, match=r".* does not support the model type .*"
+    ):
+        check_type_compatibility(evaluation)
