2222"""
2323
2424import logging
25- from typing import Dict , List , Text
25+ from typing import Dict , List , Text , Any , Tuple
2626import json
2727from aixplain .enums .supplier import Supplier
2828from aixplain .modules import Dataset , Metric , Model
@@ -150,9 +150,9 @@ def _validate_create_benchmark_payload(cls, payload):
150150 if len (payload ["datasets" ]) != 1 :
151151 raise Exception ("Please use exactly one dataset" )
152152 if len (payload ["metrics" ]) == 0 :
153- raise Exception ("Please use exactly one metric" )
154- if len (payload ["model" ]) == 0 :
155- raise Exception ("Please use exactly one model" )
153+ raise Exception ("Please use at least one metric" )
154+ if len (payload ["model" ]) == 0 and payload . get ( "models" , None ) is None :
155+ raise Exception ("Please use at least one model" )
156156 clean_metrics_info = {}
157157 for metric_info in payload ["metrics" ]:
158158 metric_id = metric_info ["id" ]
@@ -167,6 +167,31 @@ def _validate_create_benchmark_payload(cls, payload):
167167 {"id" : metric_id , "configurations" : metric_config } for metric_id , metric_config in clean_metrics_info .items ()
168168 ]
169169 return payload
170+
171+ @classmethod
172+ def _reformat_model_list (cls , model_list : List [Model ]) -> Tuple [List [Any ], List [Any ]]:
173+ """Reformat the model list to be used in the create benchmark API
174+
175+ Args:
176+ model_list (List[Model]): List of models to be used in the benchmark
177+
178+ Returns:
179+ Tuple[List[Any], List[Any]]: Reformatted model lists
180+
181+ """
182+ model_list_without_parms , model_list_with_parms = [], []
183+ for model in model_list :
184+ if "displayName" in model .additional_info :
185+ model_list_with_parms .append ({"id" : model .id , "displayName" : model .additional_info ["displayName" ], "configurations" : json .dumps (model .additional_info ["configuration" ])})
186+ else :
187+ model_list_without_parms .append (model .id )
188+ if len (model_list_with_parms ) > 0 :
189+ if len (model_list_without_parms ) > 0 :
190+ raise Exception ("Please provide addditional info for all models or for none of the models" )
191+ else :
192+ model_list_with_parms = None
193+ return model_list_without_parms , model_list_with_parms
194+
170195
171196 @classmethod
172197 def create (cls , name : str , dataset_list : List [Dataset ], model_list : List [Model ], metric_list : List [Metric ]) -> Benchmark :
@@ -186,15 +211,18 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model],
186211 try :
187212 url = urljoin (cls .backend_url , "sdk/benchmarks" )
188213 headers = {"Authorization" : f"Token { config .TEAM_API_KEY } " , "Content-Type" : "application/json" }
214+ model_list_without_parms , model_list_with_parms = cls ._reformat_model_list (model_list )
189215 payload = {
190216 "name" : name ,
191217 "datasets" : [dataset .id for dataset in dataset_list ],
192- "model" : [model .id for model in model_list ],
193218 "metrics" : [{"id" : metric .id , "configurations" : metric .normalization_options } for metric in metric_list ],
219+ "model" : model_list_without_parms ,
194220 "shapScores" : [],
195221 "humanEvaluationReport" : False ,
196222 "automodeTraining" : False ,
197223 }
224+ if model_list_with_parms is not None :
225+ payload ["models" ] = model_list_with_parms
198226 clean_payload = cls ._validate_create_benchmark_payload (payload )
199227 payload = json .dumps (clean_payload )
200228 r = _request_with_retry ("post" , url , headers = headers , data = payload )
0 commit comments