
Commit ba406e2

Introducing prompt benchmarking (#497)

* Add base support for benchmarking models with config
* bugFix: config normalization
* TypoFix: add 's' to configuration
* add display name in get_scores
* add tests for prompt benchmark
* uncomment first benchmark test
1 parent ea51442 commit ba406e2

File tree: 5 files changed, +96 -16 lines

* aixplain/factories/benchmark_factory.py
* aixplain/modules/benchmark_job.py
* aixplain/modules/model/__init__.py
* tests/functional/benchmark/benchmark_functional_test.py
* tests/functional/benchmark/data/benchmark_test_with_parameters.json
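Taken together, these changes let the same model be benchmarked under several prompt configurations. A rough end-to-end sketch, mirroring the new functional test below (the factory import path follows the existing SDK tests; every ID and name here is a placeholder):

from aixplain.factories import BenchmarkFactory, DatasetFactory, MetricFactory, ModelFactory

# Fetch the same LLM twice and attach a different prompt and display name to each copy.
model_hi = ModelFactory.get("<llm_model_id>")  # placeholder ID
model_hi.add_additional_info_for_benchmark(
    display_name="EnHi LLM",
    configuration={"prompt": "Translate the following text into Hindi."},
)

model_es = ModelFactory.get("<llm_model_id>")  # placeholder ID
model_es.add_additional_info_for_benchmark(
    display_name="EnEs LLM",
    configuration={"prompt": "Translate the following text into Spanish."},
)

dataset_list = [DatasetFactory.list(query="<benchmark dataset name>")["results"][0]]
metric_list = [MetricFactory.get("<metric_id>")]

# Both prompt variants run side by side and are reported under their display names.
benchmark = BenchmarkFactory.create("Prompt benchmark", dataset_list, [model_hi, model_es], metric_list)
benchmark_job = benchmark.start()
scores = benchmark_job.get_scores()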

aixplain/factories/benchmark_factory.py

Lines changed: 33 additions & 5 deletions
@@ -22,7 +22,7 @@
 """

 import logging
-from typing import Dict, List, Text
+from typing import Dict, List, Text, Any, Tuple
 import json
 from aixplain.enums.supplier import Supplier
 from aixplain.modules import Dataset, Metric, Model
@@ -150,9 +150,9 @@ def _validate_create_benchmark_payload(cls, payload):
         if len(payload["datasets"]) != 1:
             raise Exception("Please use exactly one dataset")
         if len(payload["metrics"]) == 0:
-            raise Exception("Please use exactly one metric")
-        if len(payload["model"]) == 0:
-            raise Exception("Please use exactly one model")
+            raise Exception("Please use at least one metric")
+        if len(payload["model"]) == 0 and payload.get("models", None) is None:
+            raise Exception("Please use at least one model")
         clean_metrics_info = {}
         for metric_info in payload["metrics"]:
             metric_id = metric_info["id"]
@@ -167,6 +167,31 @@ def _validate_create_benchmark_payload(cls, payload):
             {"id": metric_id, "configurations": metric_config} for metric_id, metric_config in clean_metrics_info.items()
         ]
         return payload
+
+    @classmethod
+    def _reformat_model_list(cls, model_list: List[Model]) -> Tuple[List[Any], List[Any]]:
+        """Reformat the model list to be used in the create benchmark API
+
+        Args:
+            model_list (List[Model]): List of models to be used in the benchmark
+
+        Returns:
+            Tuple[List[Any], List[Any]]: Reformatted model lists
+
+        """
+        model_list_without_parms, model_list_with_parms = [], []
+        for model in model_list:
+            if "displayName" in model.additional_info:
+                model_list_with_parms.append({"id": model.id, "displayName": model.additional_info["displayName"], "configurations": json.dumps(model.additional_info["configuration"])})
+            else:
+                model_list_without_parms.append(model.id)
+        if len(model_list_with_parms) > 0:
+            if len(model_list_without_parms) > 0:
+                raise Exception("Please provide addditional info for all models or for none of the models")
+        else:
+            model_list_with_parms = None
+        return model_list_without_parms, model_list_with_parms

     @classmethod
     def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], metric_list: List[Metric]) -> Benchmark:
@@ -186,15 +211,18 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model],
         try:
             url = urljoin(cls.backend_url, "sdk/benchmarks")
             headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
+            model_list_without_parms, model_list_with_parms = cls._reformat_model_list(model_list)
             payload = {
                 "name": name,
                 "datasets": [dataset.id for dataset in dataset_list],
-                "model": [model.id for model in model_list],
                 "metrics": [{"id": metric.id, "configurations": metric.normalization_options} for metric in metric_list],
+                "model": model_list_without_parms,
                 "shapScores": [],
                 "humanEvaluationReport": False,
                 "automodeTraining": False,
             }
+            if model_list_with_parms is not None:
+                payload["models"] = model_list_with_parms
             clean_payload = cls._validate_create_benchmark_payload(payload)
             payload = json.dumps(clean_payload)
             r = _request_with_retry("post", url, headers=headers, data=payload)
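Restated with plain dictionaries, the split performed by _reformat_model_list looks roughly like this (a sketch only; the real method walks Model objects and their additional_info, and the hypothetical split_models helper below is not part of the SDK):

import json

def split_models(models: list) -> tuple:
    """Sketch of the split: plain IDs go to payload["model"], configured entries to payload["models"]."""
    without_params, with_params = [], []
    for m in models:  # each m stands in for a Model: {"id": ..., "additional_info": {...}}
        info = m.get("additional_info", {})
        if "displayName" in info:
            with_params.append({
                "id": m["id"],
                "displayName": info["displayName"],
                "configurations": json.dumps(info["configuration"]),
            })
        else:
            without_params.append(m["id"])
    # Mixing configured and unconfigured models in one call is rejected.
    if with_params and without_params:
        raise Exception("Please provide additional info for all models or for none of the models")
    return without_params, with_params or None

With only configured models the first list is empty and create() sends the entries under the "models" key; with only plain models the second value is None and the payload keeps using "model".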

aixplain/modules/benchmark_job.py

Lines changed: 5 additions & 0 deletions
@@ -3,6 +3,7 @@
 from aixplain.utils import config
 from urllib.parse import urljoin
 import pandas as pd
+import json
 from pathlib import Path
 from aixplain.utils.request_utils import _request_with_retry
 from aixplain.utils.file_utils import save_file
@@ -109,6 +110,10 @@ def get_scores(self, return_simplified=True, return_as_dataframe=True):
         scores = {}
         for iteration_info in iterations:
             model_id = iteration_info["pipeline"]
+            pipeline_json = json.loads(iteration_info["pipelineJson"])
+            if "benchmark" in pipeline_json:
+                model_id = pipeline_json["benchmark"]["displayName"]
+
             model_info = {
                 "creditsUsed": round(iteration_info.get("credits", 0), 5),
                 "timeSpent": round(iteration_info.get("runtime", 0), 2),

aixplain/modules/model/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -432,6 +432,16 @@ def delete(self) -> None:
             message = "Model Deletion Error: Make sure the model exists and you are the owner."
             logging.error(message)
             raise Exception(f"{message}")
+
+    def add_additional_info_for_benchmark(self, display_name: str, configuration: Dict) -> None:
+        """Add additional info for benchmark
+
+        Args:
+            display_name (str): display name of the model
+            configuration (Dict): configuration of the model
+        """
+        self.additional_info["displayName"] = display_name
+        self.additional_info["configuration"] = configuration

     @classmethod
     def from_dict(cls, data: Dict) -> "Model":
@@ -451,3 +461,4 @@ def from_dict(cls, data: Dict) -> "Model":
             model_params=data.get("model_params"),
             **data.get("additional_info", {}),
         )
+
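A hypothetical call to the new helper, showing what lands in additional_info and is later picked up by BenchmarkFactory._reformat_model_list (the model ID is a placeholder):

from aixplain.factories import ModelFactory

model = ModelFactory.get("<llm_model_id>")  # placeholder ID
model.add_additional_info_for_benchmark(
    display_name="EnEs LLM",
    configuration={"prompt": "Translate the following text into Spanish."},
)

# additional_info now carries the two keys the benchmark payload builder looks for:
assert model.additional_info["displayName"] == "EnEs LLM"
assert model.additional_info["configuration"] == {"prompt": "Translate the following text into Spanish."}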

tests/functional/benchmark/benchmark_functional_test.py

Lines changed: 25 additions & 11 deletions
@@ -11,9 +11,7 @@
 from pathlib import Path

 import pytest
-
 import logging
-
 from aixplain import aixplain_v2 as v2

 logger = logging.getLogger()
@@ -22,6 +20,7 @@
 TIMEOUT = 60 * 30
 RUN_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_run_data.json"))
 MODULE_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_module_test_data.json"))
+RUN_WITH_PARAMETERS_FILE = str(Path(r"tests/functional/benchmark/data/benchmark_test_with_parameters.json"))


 def read_data(data_path):
@@ -33,6 +32,11 @@ def run_input_map(request):
     return request.param


+@pytest.fixture(scope="module", params=[(name, params) for name, params in read_data(RUN_WITH_PARAMETERS_FILE).items()])
+def run_with_parameters_input_map(request):
+    return request.param
+
+
 @pytest.fixture(scope="module", params=read_data(MODULE_FILE))
 def module_input_map(request):
     return request.param
@@ -79,12 +83,22 @@ def test_create_and_run(run_input_map, BenchmarkFactory):
     assert_correct_results(benchmark_job)


-# def test_module(module_input_map):
-#     benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
-#     assert benchmark.id == module_input_map["benchmark_id"]
-#     benchmark_job = benchmark.job_list[0]
-#     assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
-#     job_status = benchmark_job.check_status()
-#     assert job_status in ["in_progress", "completed"]
-#     df = benchmark_job.download_results_as_csv(return_dataframe=True)
-#     assert type(df) is pd.DataFrame
+@pytest.mark.parametrize("BenchmarkFactory", [BenchmarkFactory, v2.Benchmark])
+def test_create_and_run_with_parameters(run_with_parameters_input_map, BenchmarkFactory):
+    name, params = run_with_parameters_input_map
+    model_list = []
+    for model_info in params["models_with_parameters"]:
+        model = ModelFactory.get(model_info["model_id"])
+        model.add_additional_info_for_benchmark(display_name=model_info["display_name"], configuration=model_info["configuration"])
+        model_list.append(model)
+    dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in params["dataset_names"]]
+    metric_list = [MetricFactory.get(metric_id) for metric_id in params["metric_ids"]]
+    benchmark = BenchmarkFactory.create(f"SDK Benchmark Test With Parameters({name}) {uuid.uuid4()}", dataset_list, model_list, metric_list)
+    assert type(benchmark) is Benchmark, "Couldn't create benchmark"
+    benchmark_job = benchmark.start()
+    assert type(benchmark_job) is BenchmarkJob, "Couldn't start job"
+    assert is_job_finshed(benchmark_job), "Job did not finish in time"
+    assert_correct_results(benchmark_job)
+
+
+
tests/functional/benchmark/data/benchmark_test_with_parameters.json

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+{
+    "Translation With LLMs": {
+        "models_with_parameters": [
+            {
+                "model_id": "669a63646eb56306647e1091",
+                "display_name": "EnHi LLM",
+                "configuration": {
+                    "prompt": "Translate the following text into Hindi."
+                }
+            },
+            {
+                "model_id": "669a63646eb56306647e1091",
+                "display_name": "EnEs LLM",
+                "configuration": {
+                    "prompt": "Translate the following text into Spanish."
+                }
+            }
+        ],
+        "dataset_names": ["EnHi SDK Test - Benchmark Dataset"],
+        "metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"]
+    }
+}
