 import pandas as pd
 import json
 from dotenv import load_dotenv
+import time
 
 load_dotenv()
 from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, BenchmarkFactory
@@ -34,23 +35,55 @@ def run_input_map(request):
 def module_input_map(request):
     return request.param
 
+def is_job_finished(benchmark_job):
+    time_taken = 0
+    sleep_time = 15
+    timeout = 10 * 60  # give the job up to 10 minutes
+    while True:
+        if time_taken > timeout:
+            break
+        job_status = benchmark_job.check_status()
+        if job_status == "in_progress":
+            time.sleep(sleep_time)
+            time_taken += sleep_time
+        elif job_status == "completed":
+            return True
+        else:  # failed/aborted: stop polling early
+            break
+    return False
 
-def test_run(run_input_map):
+def assert_correct_results(benchmark_job):
+    df = benchmark_job.download_results_as_csv(return_dataframe=True)
+    assert type(df) is pd.DataFrame, "Couldn't download CSV"
+    model_success_rate = (sum(df["Model_success"]) * 100) / len(df.index)
+    assert model_success_rate > 80, f"Low model success rate ({model_success_rate})"
+    metric_name = "BLEU by sacrebleu"
+    mean_score = df[metric_name].mean()
+    assert mean_score != 0, f"Zero mean score - please check metric ({metric_name})"
+
+
+
+def test_create_and_run(run_input_map):
     model_list = [ModelFactory.get(model_id) for model_id in run_input_map["model_ids"]]
-    dataset_list = [DatasetFactory.get(dataset_id) for dataset_id in run_input_map["dataset_ids"]]
+    dataset_list = [DatasetFactory.list(query=dataset_name)["results"][0] for dataset_name in run_input_map["dataset_names"]]
     metric_list = [MetricFactory.get(metric_id) for metric_id in run_input_map["metric_ids"]]
     benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", dataset_list, model_list, metric_list)
-    assert type(benchmark) is Benchmark
+    assert type(benchmark) is Benchmark, "Couldn't create benchmark"
     benchmark_job = benchmark.start()
-    assert type(benchmark_job) is BenchmarkJob
+    assert type(benchmark_job) is BenchmarkJob, "Couldn't start job"
+    assert is_job_finished(benchmark_job), "Job did not finish in time"
+    assert_correct_results(benchmark_job)
 
 
-def test_module(module_input_map):
-    benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
-    assert benchmark.id == module_input_map["benchmark_id"]
-    benchmark_job = benchmark.job_list[0]
-    assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
-    job_status = benchmark_job.check_status()
-    assert job_status in ["in_progress", "completed"]
-    df = benchmark_job.download_results_as_csv(return_dataframe=True)
-    assert type(df) is pd.DataFrame
+
+
+
+# def test_module(module_input_map):
+#     benchmark = BenchmarkFactory.get(module_input_map["benchmark_id"])
+#     assert benchmark.id == module_input_map["benchmark_id"]
+#     benchmark_job = benchmark.job_list[0]
+#     assert benchmark_job.benchmark_id == module_input_map["benchmark_id"]
+#     job_status = benchmark_job.check_status()
+#     assert job_status in ["in_progress", "completed"]
+#     df = benchmark_job.download_results_as_csv(return_dataframe=True)
+#     assert type(df) is pd.DataFrame
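
The fixed-interval polling in `is_job_finished` generalizes to a small reusable helper; a minimal sketch (the `wait_for_status` name and its parameters are illustrative, not part of this commit):

```python
import time

def wait_for_status(get_status, done="completed", pending="in_progress",
                    timeout=10 * 60, interval=15):
    """Poll get_status() until it reports `done`, a terminal status, or timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = get_status()
        if status == done:
            return True
        if status != pending:  # failed/aborted: stop early, as the test helper does
            return False
        time.sleep(interval)
    return False

# Hypothetical wiring against a benchmark job:
# finished = wait_for_status(benchmark_job.check_status)
```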
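Outside of pytest, the end-to-end flow the rewritten test exercises looks roughly like this; a sketch using only the factory calls seen above, with placeholder asset identifiers that must be replaced with real ones from your account:

```python
import time
import uuid

from aixplain.factories import BenchmarkFactory, DatasetFactory, MetricFactory, ModelFactory

# Placeholder identifiers -- substitute real IDs/names from your own account.
model = ModelFactory.get("<model_id>")
dataset = DatasetFactory.list(query="<dataset_name>")["results"][0]
metric = MetricFactory.get("<metric_id>")

benchmark = BenchmarkFactory.create(f"SDK Benchmark Test {uuid.uuid4()}", [dataset], [model], [metric])
job = benchmark.start()

# Same polling pattern as the test: check every 15 s, give up after 10 minutes.
deadline = time.time() + 10 * 60
while time.time() < deadline and job.check_status() == "in_progress":
    time.sleep(15)

if job.check_status() == "completed":
    df = job.download_results_as_csv(return_dataframe=True)
    print(df.head())
```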