Skip to content

Commit 178d922

Browse files
authored
Merge branch 'main' into feature/odsc-65115
2 parents 043fe73 + 5451c2c commit 178d922

File tree

21 files changed

+686
-314
lines changed

21 files changed

+686
-314
lines changed

.github/workflows/run-forecast-unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ jobs:
5656
$CONDA/bin/conda init
5757
source /home/runner/.bashrc
5858
pip install -r test-requirements-operators.txt
59-
pip install "oracle-automlx[forecasting]>=24.4.0"
59+
pip install "oracle-automlx[forecasting]>=24.4.1"
6060
pip install pandas>=2.2.0
6161
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
56+
57+
class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5558
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5659

5760

ads/aqua/common/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import random
1212
import re
1313
import shlex
14+
import shutil
1415
import subprocess
1516
from datetime import datetime, timedelta
1617
from functools import wraps
@@ -21,6 +22,8 @@
2122
import fsspec
2223
import oci
2324
from cachetools import TTLCache, cached
25+
from huggingface_hub.constants import HF_HUB_CACHE
26+
from huggingface_hub.file_download import repo_folder_name
2427
from huggingface_hub.hf_api import HfApi, ModelInfo
2528
from huggingface_hub.utils import (
2629
GatedRepoError,
@@ -821,6 +824,48 @@ def upload_folder(
821824
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
822825

823826

827+
def cleanup_local_hf_model_artifact(
    model_name: str,
    local_dir: str = None,
):
    """
    Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.

    Two locations are cleaned, both on a best-effort basis (failures are only logged):

    1. ``<local_dir>/<model_name>`` when ``local_dir`` is given and exists.
    2. The model's folder inside the Hugging Face hub cache (``HF_HUB_CACHE``).

    Parameters
    ----------
    model_name (str): Name of the huggingface model
    local_dir (str): Local directory where the object is downloaded

    Returns
    -------
    None
    """
    if local_dir and os.path.exists(local_dir):
        model_dir = os.path.join(local_dir, model_name)
        shutil.rmtree(model_dir, ignore_errors=True)
        if os.path.exists(model_dir):
            logger.debug(
                f"Could not delete local model artifact directory: {model_dir}"
            )
        else:
            logger.debug(f"Deleted local model artifact directory: {model_dir}.")

        # For namespaced repo ids ("org/model") the download creates an
        # intermediate "org" directory. Remove it only when it is empty, so that
        # other models of the same organization stored under the same local_dir
        # are not deleted along with this one.
        if "/" in model_name or os.sep in model_name:
            namespace_dir = os.path.dirname(model_dir)
            try:
                os.rmdir(namespace_dir)
            except OSError:
                # Directory is missing or still holds other artifacts; keep it.
                logger.debug(
                    f"Left local namespace directory in place: {namespace_dir}"
                )

    # The hub cache keeps its own copy of the files regardless of local_dir.
    hf_local_path = os.path.join(
        HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
    )
    shutil.rmtree(hf_local_path, ignore_errors=True)

    if os.path.exists(hf_local_path):
        logger.debug(
            f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
    else:
        logger.debug(
            f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
867+
868+
824869
def is_service_managed_container(container):
    """Tell whether ``container`` starts with the service managed container URI scheme.

    A falsy ``container`` (for example ``None`` or an empty string) is returned
    unchanged; otherwise the result of the prefix check is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
826871

ads/aqua/extension/model_handler.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
from typing import Optional
@@ -8,6 +8,9 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import (
12+
CustomInferenceContainerTypeFamily,
13+
)
1114
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1215
from ads.aqua.common.utils import (
1316
get_hf_model_info,
@@ -128,6 +131,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
128131
download_from_hf = (
129132
str(input_data.get("download_from_hf", "false")).lower() == "true"
130133
)
134+
local_dir = input_data.get("local_dir")
135+
cleanup_model_cache = (
136+
str(input_data.get("cleanup_model_cache", "true")).lower() == "true"
137+
)
131138
inference_container_uri = input_data.get("inference_container_uri")
132139
allow_patterns = input_data.get("allow_patterns")
133140
ignore_patterns = input_data.get("ignore_patterns")
@@ -139,6 +146,8 @@ def post(self, *args, **kwargs): # noqa: ARG002
139146
model=model,
140147
os_path=os_path,
141148
download_from_hf=download_from_hf,
149+
local_dir=local_dir,
150+
cleanup_model_cache=cleanup_model_cache,
142151
inference_container=inference_container,
143152
finetuning_container=finetuning_container,
144153
compartment_id=compartment_id,
@@ -163,7 +172,9 @@ def put(self, id):
163172
raise HTTPError(400, Errors.NO_INPUT_DATA)
164173

165174
inference_container = input_data.get("inference_container")
175+
inference_container_uri = input_data.get("inference_container_uri")
166176
inference_containers = AquaModelApp.list_valid_inference_containers()
177+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
167178
if (
168179
inference_container is not None
169180
and inference_container not in inference_containers
@@ -176,7 +187,13 @@ def put(self, id):
176187
task = input_data.get("task")
177188
app = AquaModelApp()
178189
self.finish(
179-
app.edit_registered_model(id, inference_container, enable_finetuning, task)
190+
app.edit_registered_model(
191+
id,
192+
inference_container,
193+
inference_container_uri,
194+
enable_finetuning,
195+
task,
196+
)
180197
)
181198
app.clear_model_details_cache(model_id=id)
182199

ads/aqua/model/entities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
"""
@@ -283,6 +283,7 @@ class ImportModelDetails(CLIBuilderMixin):
283283
os_path: str
284284
download_from_hf: Optional[bool] = True
285285
local_dir: Optional[str] = None
286+
cleanup_model_cache: Optional[bool] = True
286287
inference_container: Optional[str] = None
287288
finetuning_container: Optional[str] = None
288289
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import os
55
import pathlib
@@ -15,6 +15,7 @@
1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
1717
from ads.aqua.common.enums import (
18+
CustomInferenceContainerTypeFamily,
1819
FineTuningContainerTypeFamily,
1920
InferenceContainerTypeFamily,
2021
Tags,
@@ -23,6 +24,7 @@
2324
from ads.aqua.common.utils import (
2425
LifecycleStatus,
2526
_build_resource_identifier,
27+
cleanup_local_hf_model_artifact,
2628
copy_model_config,
2729
create_word_icon,
2830
generate_tei_cmd_var,
@@ -376,8 +378,10 @@ def delete_model(self, model_id):
376378
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
377379
)
378380

379-
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
380-
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
381+
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
382+
def edit_registered_model(
383+
self, id, inference_container, inference_container_uri, enable_finetuning, task
384+
):
381385
"""Edits the default config of unverified registered model.
382386
383387
Parameters
@@ -386,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
386390
The model OCID.
387391
inference_container: str.
388392
The inference container family name
393+
inference_container_uri: str
394+
The inference container uri for embedding models
389395
enable_finetuning: str
390396
Flag to enable or disable finetuning over the model. Defaults to None
391397
task:
@@ -401,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
401407
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
402408
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
403409
raise AquaRuntimeError(
404-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
410+
"Only registered unverified models can be edited."
405411
)
406412
else:
407413
custom_metadata_list = ds_model.custom_metadata_list
408414
freeform_tags = ds_model.freeform_tags
409415
if inference_container:
410-
custom_metadata_list.add(
411-
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
412-
value=inference_container,
413-
category=MetadataCustomCategory.OTHER,
414-
description="Deployment container mapping for SMC",
415-
replace=True,
416-
)
416+
if (
417+
inference_container in CustomInferenceContainerTypeFamily
418+
and inference_container_uri is None
419+
):
420+
raise AquaRuntimeError(
421+
"Inference container URI must be provided."
422+
)
423+
else:
424+
custom_metadata_list.add(
425+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
426+
value=inference_container,
427+
category=MetadataCustomCategory.OTHER,
428+
description="Deployment container mapping for SMC",
429+
replace=True,
430+
)
431+
if inference_container_uri:
432+
if (
433+
inference_container in CustomInferenceContainerTypeFamily
434+
or inference_container is None
435+
):
436+
custom_metadata_list.add(
437+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI,
438+
value=inference_container_uri,
439+
category=MetadataCustomCategory.OTHER,
440+
description=f"Inference container URI for {ds_model.display_name}",
441+
replace=True,
442+
)
443+
else:
444+
raise AquaRuntimeError(
445+
f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
446+
)
447+
417448
if enable_finetuning is not None:
418449
if enable_finetuning.lower() == "true":
419450
custom_metadata_list.add(
@@ -448,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
448479
)
449480
AquaApp().update_model(id, update_model_details)
450481
else:
451-
raise AquaRuntimeError(
452-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
453-
)
482+
raise AquaRuntimeError("Only registered unverified models can be edited.")
454483

455484
def _fetch_metric_from_metadata(
456485
self,
@@ -869,8 +898,7 @@ def _create_model_catalog_entry(
869898
# only add cmd vars if inference container is not an SMC
870899
if (
871900
inference_container not in smc_container_set
872-
and inference_container
873-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
901+
and inference_container in CustomInferenceContainerTypeFamily.values()
874902
):
875903
cmd_vars = generate_tei_cmd_var(os_path)
876904
metadata.add(
@@ -1322,20 +1350,20 @@ def _download_model_from_hf(
13221350
Returns
13231351
-------
13241352
model_artifact_path (str): Location where the model artifacts are downloaded.
1325-
13261353
"""
13271354
# Download the model from hub
1328-
if not local_dir:
1329-
local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
1330-
local_dir = os.path.join(local_dir, model_name)
1331-
os.makedirs(local_dir, exist_ok=True)
1332-
snapshot_download(
1355+
if local_dir:
1356+
local_dir = os.path.join(local_dir, model_name)
1357+
os.makedirs(local_dir, exist_ok=True)
1358+
1359+
# if local_dir is not set, the return value points to the cached data folder
1360+
local_dir = snapshot_download(
13331361
repo_id=model_name,
13341362
local_dir=local_dir,
13351363
allow_patterns=allow_patterns,
13361364
ignore_patterns=ignore_patterns,
13371365
)
1338-
# Upload to object storage and skip .cache/huggingface/ folder
1366+
# Upload to object storage
13391367
model_artifact_path = upload_folder(
13401368
os_path=os_path,
13411369
local_dir=local_dir,
@@ -1365,6 +1393,8 @@ def register(
13651393
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13661394
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13671395
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1396+
cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
1397+
registered. Set to True by default.
13681398
13691399
Returns:
13701400
AquaModel:
@@ -1474,6 +1504,14 @@ def register(
14741504
detail=validation_result.telemetry_model_name,
14751505
)
14761506

1507+
if (
1508+
import_model_details.download_from_hf
1509+
and import_model_details.cleanup_model_cache
1510+
):
1511+
cleanup_local_hf_model_artifact(
1512+
model_name=model_name, local_dir=import_model_details.local_dir
1513+
)
1514+
14771515
return AquaModel(**aqua_model_attributes)
14781516

14791517
def _if_show(self, model: DataScienceModel) -> bool:

ads/opctl/config/merger.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
import os
87
from string import Template
98
from typing import Dict
10-
import json
119

1210
import yaml
1311

1412
from ads.common.auth import AuthType, ResourcePrincipal
1513
from ads.opctl import logger
1614
from ads.opctl.config.base import ConfigProcessor
17-
from ads.opctl.config.utils import read_from_ini, _DefaultNoneDict
18-
from ads.opctl.utils import is_in_notebook_session, get_service_pack_prefix
15+
from ads.opctl.config.utils import _DefaultNoneDict, read_from_ini
1916
from ads.opctl.constants import (
20-
DEFAULT_PROFILE,
21-
DEFAULT_OCI_CONFIG_FILE,
22-
DEFAULT_CONDA_PACK_FOLDER,
23-
DEFAULT_ADS_CONFIG_FOLDER,
24-
ADS_JOBS_CONFIG_FILE_NAME,
2517
ADS_CONFIG_FILE_NAME,
26-
ADS_ML_PIPELINE_CONFIG_FILE_NAME,
2718
ADS_DATAFLOW_CONFIG_FILE_NAME,
19+
ADS_JOBS_CONFIG_FILE_NAME,
2820
ADS_LOCAL_BACKEND_CONFIG_FILE_NAME,
21+
ADS_ML_PIPELINE_CONFIG_FILE_NAME,
2922
ADS_MODEL_DEPLOYMENT_CONFIG_FILE_NAME,
30-
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
3123
BACKEND_NAME,
24+
DEFAULT_ADS_CONFIG_FOLDER,
25+
DEFAULT_CONDA_PACK_FOLDER,
26+
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
27+
DEFAULT_OCI_CONFIG_FILE,
28+
DEFAULT_PROFILE,
3229
)
30+
from ads.opctl.utils import get_service_pack_prefix, is_in_notebook_session
3331

3432

3533
class ConfigMerger(ConfigProcessor):
@@ -41,8 +39,9 @@ class ConfigMerger(ConfigProcessor):
4139
"""
4240

4341
def process(self, **kwargs) -> None:
44-
config_string = Template(json.dumps(self.config)).safe_substitute(os.environ)
45-
self.config = json.loads(config_string)
42+
for key, value in self.config.items():
43+
if isinstance(value, str): # Substitute only if the value is a string
44+
self.config[key] = Template(value).safe_substitute(os.environ)
4645

4746
if "runtime" not in self.config:
4847
self.config["runtime"] = {}

0 commit comments

Comments
 (0)