Skip to content

Commit 011a369

Browse files
authored
Merge branch 'main' into track_md_logs_for_error_logging
2 parents 4e97417 + dc1f21b commit 011a369

File tree

14 files changed

+551
-159
lines changed

14 files changed

+551
-159
lines changed

ads/aqua/app.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77
import traceback
8+
from concurrent.futures import ThreadPoolExecutor
89
from dataclasses import fields
910
from datetime import datetime, timedelta
1011
from itertools import chain
@@ -58,6 +59,8 @@
5859
class AquaApp:
5960
"""Base Aqua App to contain common components."""
6061

62+
MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164
@telemetry(name="aqua")
6265
def __init__(self) -> None:
6366
if OCI_RESOURCE_PRINCIPAL_VERSION:
@@ -128,20 +131,69 @@ def update_model_provenance(
128131
update_model_provenance_details=update_model_provenance_details,
129132
)
130133

131-
# TODO: refactor model evaluation implementation to use it.
132134
@staticmethod
133135
def get_source(source_id: str) -> Union[ModelDeployment, DataScienceModel]:
134-
if is_valid_ocid(source_id):
135-
if "datasciencemodeldeployment" in source_id:
136-
return ModelDeployment.from_id(source_id)
137-
elif "datasciencemodel" in source_id:
138-
return DataScienceModel.from_id(source_id)
136+
"""
137+
Fetches a model or model deployment based on the provided OCID.
138+
139+
Parameters
140+
----------
141+
source_id : str
142+
OCID of the Data Science model or model deployment.
143+
144+
Returns
145+
-------
146+
Union[ModelDeployment, DataScienceModel]
147+
The corresponding resource object.
148+
149+
Raises
150+
------
151+
AquaValueError
152+
If the OCID is invalid or unsupported.
153+
"""
154+
logger.debug(f"Resolving source for ID: {source_id}")
155+
if not is_valid_ocid(source_id):
156+
logger.error(f"Invalid OCID format: {source_id}")
157+
raise AquaValueError(
158+
f"Invalid source ID: {source_id}. Please provide a valid model or model deployment OCID."
159+
)
160+
161+
if "datasciencemodeldeployment" in source_id:
162+
logger.debug(f"Identified as ModelDeployment OCID: {source_id}")
163+
return ModelDeployment.from_id(source_id)
139164

165+
if "datasciencemodel" in source_id:
166+
logger.debug(f"Identified as DataScienceModel OCID: {source_id}")
167+
return DataScienceModel.from_id(source_id)
168+
169+
logger.error(f"Unrecognized OCID type: {source_id}")
140170
raise AquaValueError(
141-
f"Invalid source {source_id}. "
142-
"Specify either a model or model deployment id."
171+
f"Unsupported source ID type: {source_id}. Must be a model or model deployment OCID."
143172
)
144173

174+
def get_multi_source(
175+
self,
176+
ids: List[str],
177+
) -> Dict[str, Union[ModelDeployment, DataScienceModel]]:
178+
"""
179+
Retrieves multiple DataScience resources concurrently.
180+
181+
Parameters
182+
----------
183+
ids : List[str]
184+
A list of DataScience OCIDs.
185+
186+
Returns
187+
-------
188+
Dict[str, Union[ModelDeployment, DataScienceModel]]
189+
A mapping from OCID to the corresponding resolved resource object.
190+
"""
191+
logger.debug(f"Fetching {ids} sources in parallel.")
192+
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
193+
results = list(executor.map(self.get_source, ids))
194+
195+
return dict(zip(ids, results))
196+
145197
# TODO: refactor model evaluation implementation to use it.
146198
@staticmethod
147199
def create_model_version_set(

ads/aqua/client/client.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,20 @@ class HttpxOCIAuth(httpx.Auth):
6161

6262
def __init__(self, signer: Optional[oci.signer.Signer] = None):
6363
"""
64-
Initialize the HttpxOCIAuth instance.
64+
Initializes the authentication handler with the given or default OCI signer.
6565
66-
Args:
67-
signer (oci.signer.Signer): The OCI signer to use for signing requests.
66+
Parameters
67+
----------
68+
signer : oci.signer.Signer, optional
69+
The OCI signer instance to use. If None, a default signer will be retrieved.
6870
"""
69-
70-
self.signer = signer or authutil.default_signer().get("signer")
71+
try:
72+
self.signer = signer or authutil.default_signer().get("signer")
73+
if not self.signer:
74+
raise ValueError("OCI signer could not be initialized.")
75+
except Exception as e:
76+
logger.error("Failed to initialize OCI signer: %s", e)
77+
raise
7178

7279
def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
7380
"""
@@ -80,21 +87,31 @@ def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]:
8087
httpx.Request: The signed HTTPX request.
8188
"""
8289
# Create a requests.Request object from the HTTPX request
83-
req = requests.Request(
84-
method=request.method,
85-
url=str(request.url),
86-
headers=dict(request.headers),
87-
data=request.content,
88-
)
89-
prepared_request = req.prepare()
90+
try:
91+
req = requests.Request(
92+
method=request.method,
93+
url=str(request.url),
94+
headers=dict(request.headers),
95+
data=request.content,
96+
)
97+
prepared_request = req.prepare()
98+
self.signer.do_request_sign(prepared_request)
99+
100+
# Replace headers on the original HTTPX request with signed headers
101+
request.headers.update(prepared_request.headers)
102+
logger.debug("Successfully signed request to %s", request.url)
90103

91-
# Sign the request using the OCI Signer
92-
self.signer.do_request_sign(prepared_request)
104+
# OCI Gateway expects a Content-Length header even on GET/DELETE requests, so set it to "0" when it is missing
105+
if (
106+
request.method in ["GET", "DELETE"]
107+
and "content-length" not in request.headers
108+
):
109+
request.headers["content-length"] = "0"
93110

94-
# Update the original HTTPX request with the signed headers
95-
request.headers.update(prepared_request.headers)
111+
except Exception as e:
112+
logger.error("Failed to sign request to %s: %s", request.url, e)
113+
raise
96114

97-
# Proceed with the request
98115
yield request
99116

100117

@@ -330,8 +347,8 @@ def _prepare_headers(
330347
"Content-Type": "application/json",
331348
"Accept": "text/event-stream" if stream else "application/json",
332349
}
333-
if stream:
334-
default_headers["enable-streaming"] = "true"
350+
# if stream:
351+
# default_headers["enable-streaming"] = "true"
335352
if headers:
336353
default_headers.update(headers)
337354

@@ -495,7 +512,7 @@ def generate(
495512
prompt: str,
496513
payload: Optional[Dict[str, Any]] = None,
497514
headers: Optional[Dict[str, str]] = None,
498-
stream: bool = True,
515+
stream: bool = False,
499516
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
500517
"""
501518
Generate text completion for the given prompt.
@@ -521,7 +538,7 @@ def chat(
521538
messages: List[Dict[str, Any]],
522539
payload: Optional[Dict[str, Any]] = None,
523540
headers: Optional[Dict[str, str]] = None,
524-
stream: bool = True,
541+
stream: bool = False,
525542
) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]:
526543
"""
527544
Perform a chat interaction with the model.

ads/aqua/client/openai_client.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ModelDeploymentBaseEndpoint(ExtendedEnum):
3232
"""Supported base endpoints for model deployments."""
3333

3434
PREDICT = "predict"
35-
PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
35+
PREDICT_WITH_RESPONSE_STREAM = "predictWithResponseStream"
3636

3737

3838
class AquaOpenAIMixin:
@@ -51,9 +51,9 @@ def _patch_route(self, original_path: str) -> str:
5151
Returns:
5252
str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
5353
"""
54-
normalized_path = original_path.lower().rstrip("/")
54+
normalized_path = original_path.rstrip("/")
5555

56-
match = re.search(r"/predict(withresponsestream)?", normalized_path)
56+
match = re.search(r"/predict(WithResponseStream)?", normalized_path)
5757
if not match:
5858
logger.debug("Route header cannot be resolved from path: %s", original_path)
5959
return ""
@@ -71,7 +71,7 @@ def _patch_route(self, original_path: str) -> str:
7171
"Route suffix does not start with a version prefix (e.g., '/v1'). "
7272
"This may lead to compatibility issues with OpenAI-style endpoints. "
7373
"Consider updating the URL to include a version prefix, "
74-
"such as '/predict/v1' or '/predictwithresponsestream/v1'."
74+
"such as '/predict/v1' or '/predictWithResponseStream/v1'."
7575
)
7676
# route_suffix = f"v1/{route_suffix}"
7777

@@ -124,13 +124,13 @@ def _patch_headers(self, request: httpx.Request) -> None:
124124

125125
def _patch_url(self) -> httpx.URL:
126126
"""
127-
Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
127+
Strips any suffixes from the base URL to retain only the `/predict` or `/predictWithResponseStream` path.
128128
129129
Returns:
130130
httpx.URL: The normalized base URL with the correct model deployment path.
131131
"""
132-
base_path = f"{self.base_url.path.lower().rstrip('/')}/"
133-
match = re.search(r"/predict(withresponsestream)?/", base_path)
132+
base_path = f"{self.base_url.path.rstrip('/')}/"
133+
match = re.search(r"/predict(WithResponseStream)?/", base_path)
134134
if match:
135135
trimmed = base_path[: match.end() - 1]
136136
return self.base_url.copy_with(path=trimmed)
@@ -144,7 +144,7 @@ def _prepare_request_common(self, request: httpx.Request) -> None:
144144
145145
This includes:
146146
- Patching headers with streaming and routing info.
147-
- Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
147+
- Normalizing the URL path to include only `/predict` or `/predictWithResponseStream`.
148148
149149
Args:
150150
request (httpx.Request): The outgoing HTTPX request.
@@ -176,6 +176,7 @@ def __init__(
176176
http_client: Optional[httpx.Client] = None,
177177
http_client_kwargs: Optional[Dict[str, Any]] = None,
178178
_strict_response_validation: bool = False,
179+
patch_headers: bool = False,
179180
**kwargs: Any,
180181
) -> None:
181182
"""
@@ -196,6 +197,7 @@ def __init__(
196197
http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
197198
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
198199
_strict_response_validation (bool, optional): Enable strict response validation.
200+
patch_headers (bool, optional): If True, each outgoing request is passed through `_prepare_request_common`, which patches the headers and normalizes the URL path. Defaults to False.
199201
**kwargs: Additional keyword arguments passed to the parent __init__.
200202
"""
201203
if http_client is None:
@@ -207,6 +209,8 @@ def __init__(
207209
logger.debug("API key not provided; using default placeholder for OCI.")
208210
api_key = "OCI"
209211

212+
self.patch_headers = patch_headers
213+
210214
super().__init__(
211215
api_key=api_key,
212216
organization=organization,
@@ -229,7 +233,8 @@ def _prepare_request(self, request: httpx.Request) -> None:
229233
Args:
230234
request (httpx.Request): The outgoing HTTP request.
231235
"""
232-
self._prepare_request_common(request)
236+
if self.patch_headers:
237+
self._prepare_request_common(request)
233238

234239

235240
class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
@@ -248,6 +253,7 @@ def __init__(
248253
http_client: Optional[httpx.Client] = None,
249254
http_client_kwargs: Optional[Dict[str, Any]] = None,
250255
_strict_response_validation: bool = False,
256+
patch_headers: bool = False,
251257
**kwargs: Any,
252258
) -> None:
253259
"""
@@ -269,6 +275,7 @@ def __init__(
269275
http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
270276
http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
271277
_strict_response_validation (bool, optional): Enable strict response validation.
278+
patch_headers (bool, optional): If True, each outgoing request is passed through `_prepare_request_common`, which patches the headers and normalizes the URL path. Defaults to False.
272279
**kwargs: Additional keyword arguments passed to the parent __init__.
273280
"""
274281
if http_client is None:
@@ -280,6 +287,8 @@ def __init__(
280287
logger.debug("API key not provided; using default placeholder for OCI.")
281288
api_key = "OCI"
282289

290+
self.patch_headers = patch_headers
291+
283292
super().__init__(
284293
api_key=api_key,
285294
organization=organization,
@@ -302,4 +311,5 @@ async def _prepare_request(self, request: httpx.Request) -> None:
302311
Args:
303312
request (httpx.Request): The outgoing HTTP request.
304313
"""
305-
self._prepare_request_common(request)
314+
if self.patch_headers:
315+
self._prepare_request_common(request)

0 commit comments

Comments
 (0)