Skip to content

Commit f6c2c57

Browse files
committed
Updated pr and added unit tests.
1 parent 895112c commit f6c2c57

File tree

4 files changed

+583
-42
lines changed

4 files changed

+583
-42
lines changed

ads/model/datascience_model_group.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,23 @@
2121
MemberModelEntries,
2222
ModelGroup,
2323
ModelGroupDetails,
24+
ModelGroupSummary,
2425
UpdateModelGroupDetails,
2526
)
26-
except ModuleNotFoundError:
27+
except ModuleNotFoundError as err:
2728
raise ModuleNotFoundError(
2829
"The oci model group module was not found. Please run `pip install oci` "
2930
"to install the latest oci sdk."
30-
)
31+
) from err
3132

3233
DEFAULT_WAIT_TIME = 1200
3334
DEFAULT_POLL_INTERVAL = 10
3435
ALLOWED_CREATE_TYPES = ["CREATE", "CLONE"]
36+
MODEL_GROUP_KIND = "datascienceModelGroup"
3537

3638

3739
class DataScienceModelGroup(Builder):
38-
"""Represents a Data Science Model.
40+
"""Represents a Data Science Model Group.
3941
4042
Attributes
4143
----------
@@ -198,6 +200,11 @@ def __init__(self, spec=None, **kwargs):
198200
super().__init__(spec, **kwargs)
199201
self.dsc_model_group = OCIDataScienceModelGroup()
200202

203+
@property
204+
def kind(self) -> str:
205+
"""The kind of the model group as showing in a YAML."""
206+
return MODEL_GROUP_KIND
207+
201208
@property
202209
def id(self) -> str:
203210
"""The model group OCID."""
@@ -503,6 +510,7 @@ def create(
503510
return self._update_from_oci_model(response)
504511

505512
def _build_model_group_details(self) -> dict:
513+
"""Builds model group details dict for creating or updating oci model group."""
506514
model_group_details = HomogeneousModelGroupDetails(
507515
custom_metadata_list=[
508516
CustomMetadata(
@@ -537,8 +545,21 @@ def _build_model_group_details(self) -> dict:
537545
return build_model_group_details
538546

539547
def _update_from_oci_model(
540-
self, oci_model_group_instance: ModelGroup
548+
self, oci_model_group_instance: Union[ModelGroup, ModelGroupSummary]
541549
) -> "DataScienceModelGroup":
550+
"""Updates self spec from oci model group instance.
551+
552+
Parameters
553+
----------
554+
oci_model_group_instance: Union[ModelGroup, ModelGroupSummary]
555+
The oci model group instance, could be an instance of oci.data_science.models.ModelGroup
556+
or oci.data_science.models.ModelGroupSummary.
557+
558+
Returns
559+
-------
560+
DataScienceModelGroup
561+
The instance of DataScienceModelGroup.
562+
"""
542563
self.dsc_model_group = oci_model_group_instance
543564
for key, value in self.attribute_map.items():
544565
if hasattr(oci_model_group_instance, value):
@@ -560,23 +581,27 @@ def _update_from_oci_model(
560581
)
561582
self.set_spec(self.CONST_CUSTOM_METADATA_LIST, model_custom_metadata)
562583

563-
member_model_entries: MemberModelEntries = (
564-
oci_model_group_instance.member_model_entries
565-
)
566-
member_model_details: List[MemberModelDetails] = (
567-
member_model_entries.member_model_details
568-
)
584+
# only updates member_models when oci_model_group_instance is an instance of
585+
# oci.data_science.models.ModelGroup as oci.data_science.models.ModelGroupSummary
586+
# doesn't have member_model_entries property.
587+
if isinstance(oci_model_group_instance, ModelGroup):
588+
member_model_entries: MemberModelEntries = (
589+
oci_model_group_instance.member_model_entries
590+
)
591+
member_model_details: List[MemberModelDetails] = (
592+
member_model_entries.member_model_details
593+
)
569594

570-
self.set_spec(
571-
self.CONST_MEMBER_MODELS,
572-
[
573-
{
574-
"inference_key": member_model_detail.inference_key,
575-
"model_id": member_model_detail.model_id,
576-
}
577-
for member_model_detail in member_model_details
578-
],
579-
)
595+
self.set_spec(
596+
self.CONST_MEMBER_MODELS,
597+
[
598+
{
599+
"inference_key": member_model_detail.inference_key,
600+
"model_id": member_model_detail.model_id,
601+
}
602+
for member_model_detail in member_model_details
603+
],
604+
)
580605

581606
return self
582607

@@ -729,18 +754,18 @@ def sync(self) -> "DataScienceModelGroup":
729754
@classmethod
730755
def list(
731756
cls,
757+
status: str = None,
732758
compartment_id: str = None,
733-
project_id: str = None,
734759
**kwargs,
735760
) -> List["DataScienceModelGroup"]:
736761
"""Lists datascience model groups in a given compartment.
737762
738763
Parameters
739764
----------
765+
status: (str, optional). Defaults to `None`.
766+
The status of model group. Allowed values: `ACTIVE`, `CREATING`, `DELETED`, `DELETING`, `FAILED` and `INACTIVE`.
740767
compartment_id: (str, optional). Defaults to `None`.
741768
The compartment OCID.
742-
project_id: (str, optional). Defaults to `None`.
743-
The project OCID.
744769
kwargs
745770
Additional keyword arguments for filtering model groups.
746771
@@ -750,9 +775,11 @@ def list(
750775
The list of the datascience model groups.
751776
"""
752777
return [
753-
cls()._update_from_oci_model(model_group)
754-
for model_group in OCIDataScienceModelGroup.list_resource(
755-
compartment_id, project_id=project_id, **kwargs
778+
cls()._update_from_oci_model(model_group_summary)
779+
for model_group_summary in OCIDataScienceModelGroup.list(
780+
status=status,
781+
compartment_id=compartment_id,
782+
**kwargs,
756783
)
757784
]
758785

ads/model/service/oci_datascience_model_group.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111

1212
from ads.common.oci_datascience import OCIDataScienceMixin
1313
from ads.common.work_request import DataScienceWorkRequest
14-
from ads.config import PROJECT_OCID
1514
from ads.model.deployment.common.utils import OCIClientManager, State
1615

1716
try:
1817
from oci.data_science.models import CreateModelGroupDetails, UpdateModelGroupDetails
19-
except:
20-
raise
18+
except ModuleNotFoundError as err:
19+
raise ModuleNotFoundError(
20+
"The oci model group module was not found. Please run `pip install oci` "
21+
"to install the latest oci sdk."
22+
) from err
2123

2224
logger = logging.getLogger(__name__)
2325

@@ -116,8 +118,8 @@ class OCIDataScienceModelGroup(
116118
Deletes datascience model group.
117119
update(self, ...) -> "OCIDataScienceModelGroup":
118120
Updates datascience model group.
119-
list(self, ...) -> list[oci.data_science.models.ModelGroup]:
120-
List oci.data_science.models.ModelGroup instances within given compartment and project.
121+
list(self, ...) -> list[oci.data_science.models.ModelGroupSummary]:
122+
List oci.data_science.models.ModelGroupSummary instances within given compartment.
121123
from_id(cls, model_group: str) -> "OCIDataScienceModelGroup":
122124
Gets model group by OCID.
123125
@@ -396,8 +398,8 @@ def update(
396398
"""
397399
if wait_for_completion:
398400
wait_for_states = [
399-
oci.data_science.models.WorkRequest.STATUS_SUCCEEDED,
400-
oci.data_science.models.WorkRequest.STATUS_FAILED,
401+
self.LIFECYCLE_STATE_ACTIVE,
402+
self.LIFECYCLE_STATE_FAILED,
401403
]
402404
else:
403405
wait_for_states = []
@@ -423,7 +425,6 @@ def list(
423425
cls,
424426
status: str = None,
425427
compartment_id: str = None,
426-
project_id: str = None,
427428
**kwargs,
428429
) -> list:
429430
"""Lists the model group associated with current compartment id and status
@@ -438,16 +439,13 @@ def list(
438439
Defaults to the compartment set in the environment variable "NB_SESSION_COMPARTMENT_OCID".
439440
If "NB_SESSION_COMPARTMENT_OCID" is not set, the root compartment ID will be used.
440441
An ValueError will be raised if root compartment ID cannot be determined.
441-
project_id : str
442-
Target project to list model groups from.
443-
Defaults to the project id in the environment variable "PROJECT_OCID".
444442
kwargs :
445443
The values are passed to oci.data_science.DataScienceClient.list_model_groups.
446444
447445
Returns
448446
-------
449447
list
450-
A list of oci.data_science.models.ModelGroup objects.
448+
A list of oci.data_science.models.ModelGroupSummary objects.
451449
452450
Raises
453451
------
@@ -461,10 +459,6 @@ def list(
461459
"Unable to determine compartment ID from environment. Specify `compartment_id`."
462460
)
463461

464-
project_id = project_id or PROJECT_OCID
465-
if project_id:
466-
kwargs["project_id"] = project_id
467-
468462
if status is not None:
469463
if status not in ALLOWED_STATUS:
470464
raise ValueError(

0 commit comments

Comments
 (0)