Skip to content

Commit 8b63630

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into feature/model_group
2 parents 41e6d6a + dc1f21b commit 8b63630

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+3692
-1597
lines changed

ads/aqua/app.py

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77
import traceback
8+
from concurrent.futures import ThreadPoolExecutor
89
from dataclasses import fields
910
from datetime import datetime, timedelta
1011
from itertools import chain
@@ -22,7 +23,7 @@
2223
from ads.aqua import logger
2324
from ads.aqua.common.entities import ModelConfigResult
2425
from ads.aqua.common.enums import ConfigFolder, Tags
25-
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
26+
from ads.aqua.common.errors import AquaValueError
2627
from ads.aqua.common.utils import (
2728
_is_valid_mvs,
2829
get_artifact_path,
@@ -58,6 +59,8 @@
5859
class AquaApp:
5960
"""Base Aqua App to contain common components."""
6061

62+
MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164
@telemetry(name="aqua")
6265
def __init__(self) -> None:
6366
if OCI_RESOURCE_PRINCIPAL_VERSION:
@@ -128,20 +131,69 @@ def update_model_provenance(
128131
update_model_provenance_details=update_model_provenance_details,
129132
)
130133

131-
# TODO: refactor model evaluation implementation to use it.
132134
@staticmethod
133135
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
134-
if is_valid_ocid(source_id):
135-
if "datasciencemodeldeployment" in source_id:
136-
return ModelDeployment.from_id(source_id)
137-
elif "datasciencemodel" in source_id:
138-
return DataScienceModel.from_id(source_id)
136+
"""
137+
Fetches a model or model deployment based on the provided OCID.
138+
139+
Parameters
140+
----------
141+
source_id : str
142+
OCID of the Data Science model or model deployment.
143+
144+
Returns
145+
-------
146+
Union[ModelDeployment, DataScienceModel]
147+
The corresponding resource object.
139148
149+
Raises
150+
------
151+
AquaValueError
152+
If the OCID is invalid or unsupported.
153+
"""
154+
logger.debug(f"Resolving source for ID: {source_id}")
155+
if not is_valid_ocid(source_id):
156+
logger.error(f"Invalid OCID format: {source_id}")
157+
raise AquaValueError(
158+
f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159+
)
160+
161+
if "datasciencemodeldeployment" in source_id:
162+
logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163+
return ModelDeployment.from_id(source_id)
164+
165+
if "datasciencemodel" in source_id:
166+
logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167+
return DataScienceModel.from_id(source_id)
168+
169+
logger.error(f"Unrecognized OCID type: {source_id}")
140170
raise AquaValueError(
141-
f"Invalid source {source_id}. "
142-
"Specify either a model or model deployment id."
171+
f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
143172
)
144173

174+
def get_multi_source(
175+
self,
176+
ids: List[str],
177+
) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178+
"""
179+
Retrieves multiple DataScience resources concurrently.
180+
181+
Parameters
182+
----------
183+
ids : List[str]
184+
A list of DataScience OCIDs.
185+
186+
Returns
187+
-------
188+
Dict[str, Union[ModelDeployment, DataScienceModel]]
189+
A mapping from OCID to the corresponding resolved resource object.
190+
"""
191+
logger.debug(f"Fetching {ids} sources in parallel.")
192+
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193+
results = list(executor.map(self.get_source, ids))
194+
195+
return dict(zip(ids, results))
196+
145197
# TODO: refactor model evaluation implementation to use it.
146198
@staticmethod
147199
def create_model_version_set(
@@ -284,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
284336
logger.info(f"Artifact not found in model {model_id}.")
285337
return False
286338

339+
@cached(cache=TTLCache(maxsize=5, ttl=timedelta(minutes=1), timer=datetime.now))
287340
def get_config_from_metadata(
288-
self, model_id: str, metadata_key: str
341+
self,
342+
model_id: str,
343+
metadata_key: str,
289344
) -> ModelConfigResult:
290345
"""Gets the config for the given Aqua model from model catalog metadata content.
291346
@@ -300,8 +355,9 @@ def get_config_from_metadata(
300355
ModelConfigResult
301356
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302357
"""
303-
config = {}
358+
config: Dict[str, Any] = {}
304359
oci_model = self.ds_client.get_model(model_id).data
360+
305361
try:
306362
config = self.ds_client.get_model_defined_metadatum_artifact_content(
307363
model_id, metadata_key
@@ -321,7 +377,7 @@ def get_config_from_metadata(
321377
)
322378
return ModelConfigResult(config=config, model_details=oci_model)
323379

324-
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
380+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
325381
def get_config(
326382
self,
327383
model_id: str,
@@ -346,8 +402,10 @@ def get_config(
346402
ModelConfigResult
347403
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348404
"""
349-
config_folder = config_folder or ConfigFolder.CONFIG
405+
config: Dict[str, Any] = {}
350406
oci_model = self.ds_client.get_model(model_id).data
407+
408+
config_folder = config_folder or ConfigFolder.CONFIG
351409
oci_aqua = (
352410
(
353411
Tags.AQUA_TAG in oci_model.freeform_tags
@@ -357,9 +415,9 @@ def get_config(
357415
else False
358416
)
359417
if not oci_aqua:
360-
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
418+
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
419+
return ModelConfigResult(config=config, model_details=oci_model)
361420

362-
config: Dict[str, Any] = {}
363421
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
364422
if not artifact_path:
365423
logger.debug(

ads/aqua/cli.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,18 @@ def _validate_value(flag, value):
9494
"If you intend to chain a function call to the result, please separate the "
9595
"flag and the subsequent function call with separator `-`."
9696
)
97+
98+
@staticmethod
99+
def install():
100+
"""Install ADS Aqua Extension from wheel file. Set enviroment variable `AQUA_EXTENSTION_PATH` to change the wheel file path.
101+
102+
Return
103+
------
104+
int:
105+
Installatation status.
106+
"""
107+
import subprocess
108+
109+
wheel_file_path = os.environ.get("AQUA_EXTENSTION_PATH", "/ads/extension/adsjupyterlab_aqua_extension*.whl")
110+
status = subprocess.run(f"pip install {wheel_file_path}",shell=True)
111+
return status.check_returncode

ads/aqua/client/client.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
6161

6262
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the HttpxOCIAuth instance.
64+
Initializes the authentication handler with the given or default OCI signer.
6565
66-
Args:
67-
signer (oci.signer.Signer): The OCI signer to use for signing requests.
66+
Parameters
67+
----------
68+
signer : oci.signer.Signer, optional
69+
The OCI signer instance to use. If None, a default signer will be retrieved.
6870
"""
69-
70-
self.signer = signer or authutil.default_signer().get("signer")
71+
try:
72+
self.signer = signer or authutil.default_signer().get("signer")
73+
if not self.signer:
74+
raise ValueError("OCI signer could not be initialized.")
75+
except Exception as e:
76+
logger.error("Failed to initialize OCI signer: %s", e)
77+
raise
7178

7279
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7380
"""
@@ -80,21 +87,31 @@ def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
8087
httpx.Request: The signed HTTPX request.
8188
"""
8289
# Create a requests.Request object from the HTTPX request
83-
req = requests.Request(
84-
method=request.method,
85-
url=str(request.url),
86-
headers=dict(request.headers),
87-
data=request.content,
88-
)
89-
prepared_request = req.prepare()
90+
try:
91+
req = requests.Request(
92+
method=request.method,
93+
url=str(request.url),
94+
headers=dict(request.headers),
95+
data=request.content,
96+
)
97+
prepared_request = req.prepare()
98+
self.signer.do_request_sign(prepared_request)
99+
100+
# Replace headers on the original HTTPX request with signed headers
101+
request.headers.update(prepared_request.headers)
102+
logger.debug("Successfully signed request to %s", request.url)
90103

91-
# Sign the request using the OCI Signer
92-
self.signer.do_request_sign(prepared_request)
104+
# Fix for GET/DELETE requests that OCI Gateway expects with Content-Length
105+
if (
106+
request.method in ["GET", "DELETE"]
107+
and "content-length" not in request.headers
108+
):
109+
request.headers["content-length"] = "0"
93110

94-
# Update the original HTTPX request with the signed headers
95-
request.headers.update(prepared_request.headers)
111+
except Exception as e:
112+
logger.error("Failed to sign request to %s: %s", request.url, e)
113+
raise
96114

97-
# Proceed with the request
98115
yield request
99116

100117

@@ -330,8 +347,8 @@ def _prepare_headers(
330347
"Content-Type": "application/json",
331348
"Accept": "text/event-stream" if stream else "application/json",
332349
}
333-
if stream:
334-
default_headers["enable-streaming"] = "true"
350+
# if stream:
351+
# default_headers["enable-streaming"] = "true"
335352
if headers:
336353
default_headers.update(headers)
337354

@@ -495,7 +512,7 @@ def generate(
495512
prompt: str,
496513
payload: Optional[Dict[str, Any]] = None,
497514
headers: Optional[Dict[str, str]] = None,
498-
stream: bool = True,
515+
stream: bool = False,
499516
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500517
"""
501518
Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ def chat(
521538
messages: List[Dict[str, Any]],
522539
payload: Optional[Dict[str, Any]] = None,
523540
headers: Optional[Dict[str, str]] = None,
524-
stream: bool = True,
541+
stream: bool = False,
525542
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526543
"""
527544
Perform a chat interaction with the model.

0 commit comments

Comments
 (0)