5
5
import json
6
6
import os
7
7
import traceback
8
+ from concurrent .futures import ThreadPoolExecutor
8
9
from dataclasses import fields
9
10
from datetime import datetime , timedelta
10
11
from itertools import chain
22
23
from ads .aqua import logger
23
24
from ads .aqua .common .entities import ModelConfigResult
24
25
from ads .aqua .common .enums import ConfigFolder , Tags
25
- from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
26
+ from ads .aqua .common .errors import AquaValueError
26
27
from ads .aqua .common .utils import (
27
28
_is_valid_mvs ,
28
29
get_artifact_path ,
58
59
class AquaApp :
59
60
"""Base Aqua App to contain common components."""
60
61
62
+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63
+
61
64
@telemetry (name = "aqua" )
62
65
def __init__ (self ) -> None :
63
66
if OCI_RESOURCE_PRINCIPAL_VERSION :
@@ -128,20 +131,69 @@ def update_model_provenance(
128
131
update_model_provenance_details = update_model_provenance_details ,
129
132
)
130
133
131
- # TODO: refactor model evaluation implementation to use it.
132
134
@staticmethod
133
135
def get_source (source_id : str ) -> Union [ModelDeployment , DataScienceModel ]:
134
- if is_valid_ocid (source_id ):
135
- if "datasciencemodeldeployment" in source_id :
136
- return ModelDeployment .from_id (source_id )
137
- elif "datasciencemodel" in source_id :
138
- return DataScienceModel .from_id (source_id )
136
+ """
137
+ Fetches a model or model deployment based on the provided OCID.
138
+
139
+ Parameters
140
+ ----------
141
+ source_id : str
142
+ OCID of the Data Science model or model deployment.
143
+
144
+ Returns
145
+ -------
146
+ Union[ModelDeployment, DataScienceModel]
147
+ The corresponding resource object.
139
148
149
+ Raises
150
+ ------
151
+ AquaValueError
152
+ If the OCID is invalid or unsupported.
153
+ """
154
+ logger .debug (f"Resolving source for ID: { source_id } " )
155
+ if not is_valid_ocid (source_id ):
156
+ logger .error (f"Invalid OCID format: { source_id } " )
157
+ raise AquaValueError (
158
+ f"Invalid source ID: { source_id } . Please provide a valid model or model deployment OCID."
159
+ )
160
+
161
+ if "datasciencemodeldeployment" in source_id :
162
+ logger .debug (f"Identified as ModelDeployment OCID: { source_id } " )
163
+ return ModelDeployment .from_id (source_id )
164
+
165
+ if "datasciencemodel" in source_id :
166
+ logger .debug (f"Identified as DataScienceModel OCID: { source_id } " )
167
+ return DataScienceModel .from_id (source_id )
168
+
169
+ logger .error (f"Unrecognized OCID type: { source_id } " )
140
170
raise AquaValueError (
141
- f"Invalid source { source_id } . "
142
- "Specify either a model or model deployment id."
171
+ f"Unsupported source ID type: { source_id } . Must be a model or model deployment OCID."
143
172
)
144
173
174
+ def get_multi_source (
175
+ self ,
176
+ ids : List [str ],
177
+ ) -> Dict [str , Union [ModelDeployment , DataScienceModel ]]:
178
+ """
179
+ Retrieves multiple DataScience resources concurrently.
180
+
181
+ Parameters
182
+ ----------
183
+ ids : List[str]
184
+ A list of DataScience OCIDs.
185
+
186
+ Returns
187
+ -------
188
+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189
+ A mapping from OCID to the corresponding resolved resource object.
190
+ """
191
+ logger .debug (f"Fetching { ids } sources in parallel." )
192
+ with ThreadPoolExecutor (max_workers = self .MAX_WORKERS ) as executor :
193
+ results = list (executor .map (self .get_source , ids ))
194
+
195
+ return dict (zip (ids , results ))
196
+
145
197
# TODO: refactor model evaluation implementation to use it.
146
198
@staticmethod
147
199
def create_model_version_set (
@@ -284,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
284
336
logger .info (f"Artifact not found in model { model_id } ." )
285
337
return False
286
338
339
+ @cached (cache = TTLCache (maxsize = 5 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
287
340
def get_config_from_metadata (
288
- self , model_id : str , metadata_key : str
341
+ self ,
342
+ model_id : str ,
343
+ metadata_key : str ,
289
344
) -> ModelConfigResult :
290
345
"""Gets the config for the given Aqua model from model catalog metadata content.
291
346
@@ -300,8 +355,9 @@ def get_config_from_metadata(
300
355
ModelConfigResult
301
356
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302
357
"""
303
- config = {}
358
+ config : Dict [ str , Any ] = {}
304
359
oci_model = self .ds_client .get_model (model_id ).data
360
+
305
361
try :
306
362
config = self .ds_client .get_model_defined_metadatum_artifact_content (
307
363
model_id , metadata_key
@@ -321,7 +377,7 @@ def get_config_from_metadata(
321
377
)
322
378
return ModelConfigResult (config = config , model_details = oci_model )
323
379
324
- @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
380
+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
325
381
def get_config (
326
382
self ,
327
383
model_id : str ,
@@ -346,8 +402,10 @@ def get_config(
346
402
ModelConfigResult
347
403
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348
404
"""
349
- config_folder = config_folder or ConfigFolder . CONFIG
405
+ config : Dict [ str , Any ] = {}
350
406
oci_model = self .ds_client .get_model (model_id ).data
407
+
408
+ config_folder = config_folder or ConfigFolder .CONFIG
351
409
oci_aqua = (
352
410
(
353
411
Tags .AQUA_TAG in oci_model .freeform_tags
@@ -357,9 +415,9 @@ def get_config(
357
415
else False
358
416
)
359
417
if not oci_aqua :
360
- raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
418
+ logger .debug (f"Target model { oci_model .id } is not an Aqua model." )
419
+ return ModelConfigResult (config = config , model_details = oci_model )
361
420
362
- config : Dict [str , Any ] = {}
363
421
artifact_path = get_artifact_path (oci_model .custom_metadata_list )
364
422
if not artifact_path :
365
423
logger .debug (
0 commit comments