@@ -14,7 +14,7 @@
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import jinja2
from pydantic import BaseModel, ConfigDict, Field
@@ -48,7 +48,7 @@ class ApiEndpoint(BaseModel):
stream: Optional[bool] = Field(
description="Whether responses should be streamed", default=None
)
type: Optional[EndpointType] = Field(
type: Optional[Union[EndpointType, list[EndpointType]]] = Field(
description="The type of the target", default=None
)
url: Optional[str] = Field(description="Url of the model", default=None)
@@ -108,7 +108,9 @@ class EvaluationConfig(BaseModel):
params: Optional[ConfigParams] = Field(
description="Parameters to be used for evaluation", default=None
)
supported_endpoint_types: Optional[list[str]] = Field(
supported_endpoint_types: Optional[
Union[EndpointType, list[Union[EndpointType, list[EndpointType]]]]
] = Field(
description="Supported endpoint types like chat or completions", default=None
)
type: Optional[str] = Field(description="Type of the task", default=None)
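
For illustration only (not part of the diff): with this change, supported_endpoint_types can hold a single EndpointType, a flat list of alternatives, or nested lists describing capability combinations the model must expose together. A minimal sketch of the accepted shapes, assuming the package is importable as in the tests below:

from nemo_evaluator.api.api_dataclasses import EndpointType, EvaluationConfig

# A single supported endpoint type.
EvaluationConfig(supported_endpoint_types=EndpointType.CHAT)

# Flat list: the model may expose either of these types.
EvaluationConfig(supported_endpoint_types=[EndpointType.CHAT, EndpointType.COMPLETIONS])

# Nested lists: each inner list is a combination of capabilities the model must
# expose together, and the model has to match exactly one such combination.
EvaluationConfig(
    supported_endpoint_types=[
        [EndpointType.COMPLETIONS, EndpointType.VLM],
        [EndpointType.CHAT, EndpointType.VLM],
    ]
)
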
@@ -50,6 +50,7 @@ def get_args() -> argparse.Namespace:
parser_run.add_argument(
"--model_type",
type=str,
nargs="+",
help="Run config.: endpoint type",
choices=["chat", "completions", "vlm", "embedding"],
)
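
For context (not part of the diff): nargs="+" makes argparse accept one or more values, so several endpoint types can be passed to --model_type in a single run. A minimal sketch mirroring the parser definition above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_type",
    type=str,
    nargs="+",
    choices=["chat", "completions", "vlm", "embedding"],
)

# e.g. "--model_type chat vlm" now parses into a list of types:
print(parser.parse_args(["--model_type", "chat", "vlm"]).model_type)  # ['chat', 'vlm']
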
67 changes: 53 additions & 14 deletions packages/nemo-evaluator/src/nemo_evaluator/core/input.py
@@ -21,6 +21,7 @@

from nemo_evaluator.adapters.adapter_config import AdapterConfig
from nemo_evaluator.api.api_dataclasses import (
EndpointType,
Evaluation,
EvaluationConfig,
EvaluationTarget,
@@ -347,26 +348,64 @@ def get_evaluation(
return Evaluation(**merged_configuration)


def _get_benchmark_types_readable(supported_endpoint_types: list[EndpointType]) -> str:
readable_form = ""
for supported_combination in supported_endpoint_types:
if not isinstance(supported_combination, list):
readable_form += f"\n- {supported_combination.value}"
else:
readable_form += (
f"\n- [{[m_type.value for m_type in supported_combination]}]"
)
return readable_form


def check_type_compatibility(evaluation: Evaluation):
if (
evaluation.config.supported_endpoint_types is not None
and evaluation.target.api_endpoint.type
not in evaluation.config.supported_endpoint_types
):
# Model endpoint types must be checked against the benchmark's required capabilities:
# every capability required by the benchmark must be present among the model's endpoint types.

# If the evaluation does not specify particular endpoint types,
# we treat it as 'any'.

# We have to be careful with the types here: we might otherwise end up turning a
# stringable dataclass into a set.
if evaluation.config.supported_endpoint_types is not None:
if evaluation.target.api_endpoint.type is None:
raise MisconfigurationError(
"target.api_endpoint.type should be defined and match one of the endpoint "
f"types supported by the benchmark: '{evaluation.config.supported_endpoint_types}'",
"target.api_endpoint.type (CLI: --model_type) should be defined and match one of the endpoint "
f"types supported by the benchmark: {_get_benchmark_types_readable(evaluation.config.supported_endpoint_types)}"
)
if (
evaluation.target.api_endpoint.type
not in evaluation.config.supported_endpoint_types
):
if not isinstance(evaluation.target.api_endpoint.type, list):
evaluation.target.api_endpoint.type = [evaluation.target.api_endpoint.type]

if not isinstance(evaluation.config.supported_endpoint_types, list):
evaluation.config.supported_endpoint_types = [
evaluation.config.supported_endpoint_types
]
model_types = set(evaluation.target.api_endpoint.type)
is_target_compatible = False
for benchmark_type_combination in evaluation.config.supported_endpoint_types:
if not isinstance(benchmark_type_combination, list):
benchmark_type_combination = [benchmark_type_combination]

if model_types.issuperset(set(benchmark_type_combination)):
if is_target_compatible:
raise MisconfigurationError(
f"The benchmark {evaluation.config.type} is compatible with more than one combination of model capabilities {evaluation.target.api_endpoint.type} and needs a specification. Please override model capabilities for this benchmark to match only one combination."
)
else:
is_target_compatible = True

if not is_target_compatible:
raise MisconfigurationError(
f"The benchmark '{evaluation.config.type}' does not support the model type '{evaluation.target.api_endpoint.type}'. "
f"The benchmark supports '{evaluation.config.supported_endpoint_types}'."
f"The benchmark '{evaluation.config.type}' does not support any of the model types '{evaluation.target.api_endpoint.type}'. \n"
f"The benchmark supports: {_get_benchmark_types_readable(evaluation.config.supported_endpoint_types)}"
)

# Unpack the types back; wrapping them in lists above was only for the compatibility check.
evaluation.config.supported_endpoint_types = (
evaluation.config.supported_endpoint_types[0]
)
evaluation.target.api_endpoint.type = evaluation.target.api_endpoint.type[0]
if evaluation.target.api_endpoint.type:
# Check this only if the model is really required (to accommodate non-model evals)
if evaluation.target.api_endpoint.url is None:
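
To summarize the rule implemented above (illustration only, using plain strings instead of the library types): the model's endpoint types must cover exactly one of the benchmark's supported combinations; zero matches or more than one match raises MisconfigurationError.

model_types = {"chat", "vlm"}
benchmark_combinations = [["completions", "vlm"], ["chat", "vlm"]]

matches = [
    combination
    for combination in benchmark_combinations
    if model_types.issuperset(combination)
]
assert len(matches) == 1  # 0 -> "does not support ...", >1 -> "more than one combination ..."
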
117 changes: 116 additions & 1 deletion packages/nemo-evaluator/tests/unit_tests/core/test_input.py
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.


from nemo_evaluator.core.input import merge_dicts
import pytest

from nemo_evaluator.api.api_dataclasses import (
ApiEndpoint,
EndpointType,
Evaluation,
EvaluationConfig,
EvaluationTarget,
)
from nemo_evaluator.core.input import check_type_compatibility, merge_dicts
from nemo_evaluator.core.utils import MisconfigurationError


def test_distinct_keys():
@@ -64,3 +74,108 @@ def test_empty_dicts():
d2 = {}
assert merge_dicts(d1, d2) == {"b": 2}
assert merge_dicts({}, {}) == {}


@pytest.mark.parametrize(
"model_types,benchmark_types",
[
(EndpointType.CHAT, EndpointType.CHAT),
([EndpointType.CHAT], [EndpointType.CHAT]),
(EndpointType.CHAT, [EndpointType.CHAT]),
([EndpointType.CHAT], EndpointType.CHAT),
("chat", "chat"),
("chat", None),
([EndpointType.CHAT, EndpointType.COMPLETIONS], [EndpointType.CHAT]),
([EndpointType.CHAT, EndpointType.COMPLETIONS], EndpointType.CHAT),
(EndpointType.CHAT, [[EndpointType.CHAT], [EndpointType.COMPLETIONS]]),
(
[EndpointType.CHAT, EndpointType.VLM],
[
[EndpointType.COMPLETIONS, EndpointType.VLM],
[EndpointType.CHAT, EndpointType.VLM],
],
),
],
)
def test_endpoint_type_single_compatible(model_types, benchmark_types):
evaluation_config = EvaluationConfig(supported_endpoint_types=benchmark_types)
target_config = EvaluationTarget(
api_endpoint=ApiEndpoint(type=model_types, url="localhost", model_id="my_model")
)
evaluation = Evaluation(
config=evaluation_config,
target=target_config,
command="",
pkg_name="",
framework_name="",
)
check_type_compatibility(evaluation)


@pytest.mark.parametrize(
"model_types,benchmark_types",
[
(EndpointType.CHAT, EndpointType.COMPLETIONS),
("chat", "vlm"),
([EndpointType.CHAT], [[EndpointType.CHAT, EndpointType.VLM]]),
(
[EndpointType.CHAT, EndpointType.VLM],
[[EndpointType.COMPLETIONS, EndpointType.VLM]],
),
],
)
def test_endpoint_type_single_incompatible(model_types, benchmark_types):
evaluation_config = EvaluationConfig(supported_endpoint_types=benchmark_types)
target_config = EvaluationTarget(
api_endpoint=ApiEndpoint(type=model_types, url="localhost", model_id="my_model")
)
evaluation = Evaluation(
config=evaluation_config,
target=target_config,
command="",
pkg_name="",
framework_name="",
)
with pytest.raises(
MisconfigurationError, match=r".* does not support any of the model types .*"
):
check_type_compatibility(evaluation)


@pytest.mark.parametrize(
"model_types,benchmark_types",
[
(
[EndpointType.CHAT, EndpointType.COMPLETIONS],
[EndpointType.CHAT, EndpointType.COMPLETIONS],
),
(
[EndpointType.CHAT, EndpointType.COMPLETIONS, EndpointType.VLM],
[EndpointType.CHAT, EndpointType.COMPLETIONS],
),
(
[EndpointType.CHAT, EndpointType.COMPLETIONS, EndpointType.VLM],
[
[EndpointType.COMPLETIONS, EndpointType.VLM],
[EndpointType.CHAT, EndpointType.VLM],
],
),
],
)
def test_endpoint_type_raise_on_more_than_one(model_types, benchmark_types):
evaluation_config = EvaluationConfig(supported_endpoint_types=benchmark_types)
target_config = EvaluationTarget(
api_endpoint=ApiEndpoint(type=model_types, url="localhost", model_id="my_model")
)
evaluation = Evaluation(
config=evaluation_config,
target=target_config,
command="",
pkg_name="",
framework_name="",
)
with pytest.raises(
MisconfigurationError,
match=r".* is compatible with more than one combination of model capabilities .*",
):
check_type_compatibility(evaluation)