Skip to content

Commit b53ca78

Browse files
M 5608818958 update benchmark tests (#98)
* Update benchmark_functional_test.py * Update benchmark_test_run_data.json
1 parent f668332 commit b53ca78

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

tests/functional/benchmark/benchmark_functional_test.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import json
44
from dotenv import load_dotenv
5+
import time
56

67
load_dotenv()
78
from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, BenchmarkFactory
@@ -34,23 +35,55 @@ def run_input_map(request):
3435
def module_input_map(request):
3536
return request.param
3637

38+
def is_job_finshed(benchmark_job):
39+
time_taken = 0
40+
sleep_time = 15
41+
timeout = 10 * 60
42+
while True:
43+
if time_taken > timeout:
44+
break
45+
job_status = benchmark_job.check_status()
46+
if job_status == "in_progress":
47+
time.sleep(sleep_time)
48+
time_taken += sleep_time
49+
elif job_status == "completed":
50+
return True
51+
else:
52+
break
53+
return False
3754

38-
def test_run(run_input_map):
55+
def assert_correct_results(benchmark_job):
56+
df = benchmark_job.download_results_as_csv(return_dataframe=True)
57+
assert type(df) is pd.DataFrame, "Couldn't download CSV"
58+
model_success_rate = (sum(df["Model_success"])*100)/len(df.index)
59+
assert model_success_rate > 80 , f"Low model success rate ({model_success_rate})"
60+
metric_name = "BLEU by sacrebleu"
61+
mean_score = df[metric_name].mean()
62+
assert mean_score != 0 , f"Zero Mean Score - Please check metric ({metric_name})"
63+
64+
65+
66+
def test_create_and_run(run_input_map):
3967
model_list = [ModelFactory.get(model_id) for model_id in run_input_map["model_ids"]]
40-
dataset_list = [DatasetFactory.get(dataset_id) for dataset_id in run_input_map["dataset_ids"]]
68+
dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in run_input_map["dataset_names"]]
4169
metric_list = [MetricFactory.get(metric_id) for metric_id in run_input_map["metric_ids"]]
4270
benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", dataset_list, model_list, metric_list)
43-
assert type(benchmark) is Benchmark
71+
assert type(benchmark) is Benchmark, "Couldn't create benchmark"
4472
benchmark_job = benchmark.start()
45-
assert type(benchmark_job) is BenchmarkJob
73+
assert type(benchmark_job) is BenchmarkJob, "Couldn't start job"
74+
assert is_job_finshed(benchmark_job), "Job did not finish in time"
75+
assert_correct_results(benchmark_job)
4676

4777

48-
def test_module(module_input_map):
49-
benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
50-
assert benchmark.id == module_input_map["benchmark_id"]
51-
benchmark_job = benchmark.job_list[0]
52-
assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
53-
job_status = benchmark_job.check_status()
54-
assert job_status in ["in_progress", "completed"]
55-
df = benchmark_job.download_results_as_csv(return_dataframe=True)
56-
assert type(df) is pd.DataFrame
78+
79+
80+
81+
# def test_module(module_input_map):
82+
# benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
83+
# assert benchmark.id == module_input_map["benchmark_id"]
84+
# benchmark_job = benchmark.job_list[0]
85+
# assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
86+
# job_status = benchmark_job.check_status()
87+
# assert job_status in ["in_progress", "completed"]
88+
# df = benchmark_job.download_results_as_csv(return_dataframe=True)
89+
# assert type(df) is pd.DataFrame

tests/functional/benchmark/data/benchmark_test_run_data.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{
33
"model_ids": ["61b097551efecf30109d32da", "60ddefbe8d38c51c5885f98a"],
44
"dataset_ids": ["64da34a813d879bec2323aa3"],
5+
"dataset_names": ["EnHi SDK Test - Benchmark Dataset"],
56
"metric_ids": ["639874ab506c987b1ae1acc6", "6408942f166427039206d71e"]
67
}
7-
]
8+
]

0 commit comments

Comments
 (0)