Skip to content

Commit 895112c

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into model_group
2 parents 99ef319 + 33c9966 commit 895112c

File tree

28 files changed

+659
-316
lines changed

28 files changed

+659
-316
lines changed

ads/aqua/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self) -> None:
6464
set_auth("resource_principal")
6565
self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
6666
self.ds_client = oc.OCIClientFactory(**self._auth).data_science
67+
self.compute_client = oc.OCIClientFactory(**default_signer()).compute
6768
self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
6869
self.identity_client = oc.OCIClientFactory(**default_signer()).identity
6970
self.region = extract_region(self._auth)

ads/aqua/common/entities.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ class AquaMultiModelRef(Serializable):
157157
Optional environment variables to override during deployment.
158158
artifact_location : Optional[str]
159159
Artifact path of model in the multimodel group.
160+
fine_tune_weights_location : Optional[str]
161+
For fine-tuned models, the artifact path of the modified model weights.
160162
"""
161163

162164
model_id: str = Field(..., description="The model OCID to deploy.")
@@ -171,6 +173,9 @@ class AquaMultiModelRef(Serializable):
171173
artifact_location: Optional[str] = Field(
172174
None, description="Artifact path of model in the multimodel group."
173175
)
176+
fine_tune_weights_location: Optional[str] = Field(
177+
None, description="For fine tuned models, the artifact path of the modified model weights"
178+
)
174179

175180
class Config:
176181
extra = "ignore"

ads/aqua/common/enums.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ class Resource(ExtendedEnum):
2020
MODEL_VERSION_SET = "model-version-sets"
2121

2222

23+
class PredictEndpoints(ExtendedEnum):
24+
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
25+
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
26+
EMBEDDING_ENDPOINT = "/v1/embedding"
27+
28+
2329
class Tags(ExtendedEnum):
2430
TASK = "task"
2531
LICENSE = "license"
@@ -49,6 +55,7 @@ class InferenceContainerType(ExtendedEnum):
4955
class InferenceContainerTypeFamily(ExtendedEnum):
5056
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5157
AQUA_VLLM_V1_CONTAINER_FAMILY = "odsc-vllm-serving-v1"
58+
AQUA_VLLM_LLAMA4_CONTAINER_FAMILY = "odsc-vllm-serving-llama4"
5259
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5360
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
5461

@@ -119,4 +126,9 @@ class Platform(ExtendedEnum):
119126
InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY,
120127
InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY,
121128
],
129+
InferenceContainerTypeFamily.AQUA_VLLM_LLAMA4_CONTAINER_FAMILY: [
130+
InferenceContainerTypeFamily.AQUA_VLLM_LLAMA4_CONTAINER_FAMILY,
131+
InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY,
132+
InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY,
133+
],
122134
}

ads/aqua/common/utils.py

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,9 @@ def get_params_dict(params: Union[str, List[str]]) -> dict:
832832
"""
833833
params_list = get_params_list(params) if isinstance(params, str) else params
834834
return {
835-
split_result[0]: split_result[1] if len(split_result) > 1 else UNKNOWN
835+
split_result[0]: " ".join(split_result[1:])
836+
if len(split_result) > 1
837+
else UNKNOWN
836838
for split_result in (x.split() for x in params_list)
837839
}
838840

@@ -881,7 +883,9 @@ def build_params_string(params: dict) -> str:
881883
A params string.
882884
"""
883885
return (
884-
" ".join(f"{name} {value}" for name, value in params.items()).strip()
886+
" ".join(
887+
f"{name} {value}" if value else f"{name}" for name, value in params.items()
888+
).strip()
885889
if params
886890
else UNKNOWN
887891
)
@@ -1158,9 +1162,11 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11581162

11591163

11601164
def build_pydantic_error_message(ex: ValidationError):
1161-
"""Added to handle error messages from pydantic model validator.
1165+
"""
1166+
Added to handle error messages from pydantic model validator.
11621167
Combine both loc and msg for errors where loc (field) is present in error details, else only build error
1163-
message using msg field."""
1168+
message using msg field.
1169+
"""
11641170

11651171
return {
11661172
".".join(map(str, e["loc"])): e["msg"]
@@ -1185,67 +1191,71 @@ def is_pydantic_model(obj: object) -> bool:
11851191

11861192
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now))
11871193
def load_gpu_shapes_index(
1188-
auth: Optional[Dict] = None,
1194+
auth: Optional[Dict[str, Any]] = None,
11891195
) -> GPUShapesIndex:
11901196
"""
1191-
Loads the GPU shapes index from Object Storage or a local resource folder.
1197+
Load the GPU shapes index, preferring the OS bucket copy over the local one.
11921198
1193-
The function first attempts to load the file from an Object Storage bucket using fsspec.
1194-
If the loading fails (due to connection issues, missing file, etc.), it falls back to
1195-
loading the index from a local file.
1199+
Attempts to read `gpu_shapes_index.json` from OCI Object Storage first;
1200+
if that succeeds, those entries will override the local defaults.
11961201
11971202
Parameters
11981203
----------
1199-
auth: (Dict, optional). Defaults to None.
1200-
The default authentication is set using `ads.set_auth` API. If you need to override the
1201-
default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1202-
authentication signer and kwargs required to instantiate IdentityClient object.
1204+
auth
1205+
Optional auth dict (as returned by `ads.common.auth.default_signer()`)
1206+
to pass through to `fsspec.open()`.
12031207
12041208
Returns
12051209
-------
1206-
GPUShapesIndex: The parsed GPU shapes index.
1210+
GPUShapesIndex
1211+
Merged index where any shape present remotely supersedes the local entry.
12071212
12081213
Raises
12091214
------
1210-
FileNotFoundError: If the GPU shapes index cannot be found in either Object Storage or locally.
1211-
json.JSONDecodeError: If the JSON is malformed.
1215+
json.JSONDecodeError
1216+
If any of the JSON is malformed.
12121217
"""
12131218
file_name = "gpu_shapes_index.json"
1214-
data: Dict[str, Any] = {}
12151219

1216-
# Check if the CONDA_BUCKET_NS environment variable is set.
1220+
# Try remote load
1221+
remote_data: Dict[str, Any] = {}
12171222
if CONDA_BUCKET_NS:
12181223
try:
12191224
auth = auth or authutil.default_signer()
1220-
# Construct the object storage path. Adjust bucket name and path as needed.
12211225
storage_path = (
12221226
f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
12231227
)
1224-
logger.debug("Loading GPU shapes index from Object Storage")
1225-
with fsspec.open(storage_path, mode="r", **auth) as file_obj:
1226-
data = json.load(file_obj)
1227-
logger.debug("Successfully loaded GPU shapes index.")
1228-
except Exception as ex:
12291228
logger.debug(
1230-
f"Failed to load GPU shapes index from Object Storage. Details: {ex}"
1231-
)
1232-
1233-
# If loading from Object Storage failed, load from the local resource folder.
1234-
if not data:
1235-
try:
1236-
local_path = os.path.join(
1237-
os.path.dirname(__file__), "../resources", file_name
1229+
"Loading GPU shapes index from Object Storage: %s", storage_path
12381230
)
1239-
logger.debug(f"Loading GPU shapes index from {local_path}.")
1240-
with open(local_path) as file_obj:
1241-
data = json.load(file_obj)
1242-
logger.debug("Successfully loaded GPU shapes index.")
1243-
except Exception as e:
1231+
with fsspec.open(storage_path, mode="r", **auth) as f:
1232+
remote_data = json.load(f)
12441233
logger.debug(
1245-
f"Failed to load GPU shapes index from {local_path}. Details: {e}"
1234+
"Loaded %d shapes from Object Storage",
1235+
len(remote_data.get("shapes", {})),
12461236
)
1237+
except Exception as ex:
1238+
logger.debug("Remote load failed (%s); falling back to local", ex)
1239+
1240+
# Load local copy
1241+
local_data: Dict[str, Any] = {}
1242+
local_path = os.path.join(os.path.dirname(__file__), "../resources", file_name)
1243+
try:
1244+
logger.debug("Loading GPU shapes index from local file: %s", local_path)
1245+
with open(local_path) as f:
1246+
local_data = json.load(f)
1247+
logger.debug(
1248+
"Loaded %d shapes from local file", len(local_data.get("shapes", {}))
1249+
)
1250+
except Exception as ex:
1251+
logger.debug("Local load GPU shapes index failed (%s)", ex)
1252+
1253+
# Merge: remote shapes override local
1254+
local_shapes = local_data.get("shapes", {})
1255+
remote_shapes = remote_data.get("shapes", {})
1256+
merged_shapes = {**local_shapes, **remote_shapes}
12471257

1248-
return GPUShapesIndex(**data)
1258+
return GPUShapesIndex(shapes=merged_shapes)
12491259

12501260

12511261
def get_preferred_compatible_family(selected_families: set[str]) -> str:

ads/aqua/config/container_config.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from oci.data_science.models import ContainerSummary
88
from pydantic import Field
99

10+
from ads.aqua import logger
1011
from ads.aqua.config.utils.serializer import Serializable
1112
from ads.aqua.constants import (
1213
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
@@ -168,50 +169,47 @@ def from_service_config(
168169
container_type = container.family_name
169170
usages = [x.upper() for x in container.usages]
170171
if "INFERENCE" in usages or "MULTI_MODEL" in usages:
172+
# Extract additional configurations
173+
additional_configurations = {}
174+
try:
175+
additional_configurations = (
176+
container.workload_configuration_details_list[
177+
0
178+
].additional_configurations
179+
)
180+
except (AttributeError, IndexError) as ex:
181+
logger.debug(
182+
"Failed to extract `additional_configurations` for container '%s': %s",
183+
getattr(container, "container_name", "<unknown>"),
184+
ex,
185+
)
186+
171187
container_item.platforms.append(
172-
container.workload_configuration_details_list[
173-
0
174-
].additional_configurations.get("platforms")
188+
additional_configurations.get("platforms")
175189
)
176190
container_item.model_formats.append(
177-
container.workload_configuration_details_list[
178-
0
179-
].additional_configurations.get("modelFormats")
191+
additional_configurations.get("modelFormats")
180192
)
193+
194+
# Parse environment variables from `additional_configurations`.
195+
# Only keys present in the configuration will be added to the result.
196+
config_keys = {
197+
"MODEL_DEPLOY_PREDICT_ENDPOINT": UNKNOWN,
198+
"MODEL_DEPLOY_HEALTH_ENDPOINT": UNKNOWN,
199+
"MODEL_DEPLOY_ENABLE_STREAMING": UNKNOWN,
200+
"PORT": UNKNOWN,
201+
"HEALTH_CHECK_PORT": UNKNOWN,
202+
"VLLM_USE_V1": UNKNOWN,
203+
}
204+
181205
env_vars = [
182-
{
183-
"MODEL_DEPLOY_PREDICT_ENDPOINT": container.workload_configuration_details_list[
184-
0
185-
].additional_configurations.get(
186-
"MODEL_DEPLOY_PREDICT_ENDPOINT", UNKNOWN
187-
)
188-
},
189-
{
190-
"MODEL_DEPLOY_HEALTH_ENDPOINT": container.workload_configuration_details_list[
191-
0
192-
].additional_configurations.get(
193-
"MODEL_DEPLOY_HEALTH_ENDPOINT", UNKNOWN
194-
)
195-
},
196-
{
197-
"MODEL_DEPLOY_ENABLE_STREAMING": container.workload_configuration_details_list[
198-
0
199-
].additional_configurations.get(
200-
"MODEL_DEPLOY_ENABLE_STREAMING", UNKNOWN
201-
)
202-
},
203-
{
204-
"PORT": container.workload_configuration_details_list[
205-
0
206-
].additional_configurations.get("PORT", "")
207-
},
208-
{
209-
"HEALTH_CHECK_PORT": container.workload_configuration_details_list[
210-
0
211-
].additional_configurations.get("HEALTH_CHECK_PORT", UNKNOWN),
212-
},
206+
{key: additional_configurations.get(key, default)}
207+
for key, default in config_keys.items()
208+
if key in additional_configurations
213209
]
214-
container_spec = AquaContainerConfigSpec(
210+
211+
# Build container spec
212+
container_item.spec = AquaContainerConfigSpec(
215213
cli_param=container.workload_configuration_details_list[0].cmd,
216214
server_port=str(
217215
container.workload_configuration_details_list[0].server_port
@@ -236,13 +234,14 @@ def from_service_config(
236234
)
237235
),
238236
)
239-
container_item.spec = container_spec
237+
240238
if "INFERENCE" in usages or "MULTI_MODEL" in usages:
241239
inference_items[container_type] = container_item
242240
if "FINE_TUNE" in usages:
243241
finetune_items[container_type] = container_item
244242
if "EVALUATION" in usages:
245243
evaluate_items[container_type] = container_item
244+
246245
return cls(
247246
inference=inference_items, finetune=finetune_items, evaluate=evaluate_items
248247
)

0 commit comments

Comments
 (0)