|
35 | 35 | from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
|
36 | 36 | from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
|
37 | 37 |
|
38 |
| -from ads.common.dsc_file_system import OCIFileStorage, DSCFileSystemManager, OCIObjectStorage |
| 38 | +from ads.common.dsc_file_system import ( |
| 39 | + OCIFileStorage, |
| 40 | + DSCFileSystemManager, |
| 41 | + OCIObjectStorage, |
| 42 | +) |
39 | 43 |
|
40 | 44 | logger = logging.getLogger(__name__)
|
41 | 45 |
|
@@ -1454,11 +1458,14 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
|
1454 | 1458 | if value:
|
1455 | 1459 | dsc_job.job_infrastructure_configuration_details[camel_attr] = value
|
1456 | 1460 |
|
1457 |
| - if ( |
1458 |
| - not dsc_job.job_infrastructure_configuration_details.get("shapeName", "").endswith("Flex") |
1459 |
| - and dsc_job.job_infrastructure_configuration_details.get("jobShapeConfigDetails") |
| 1461 | + if not dsc_job.job_infrastructure_configuration_details.get( |
| 1462 | + "shapeName", "" |
| 1463 | + ).endswith("Flex") and dsc_job.job_infrastructure_configuration_details.get( |
| 1464 | + "jobShapeConfigDetails" |
1460 | 1465 | ):
|
1461 |
| - raise ValueError("Shape config is not required for non flex shape from user end.") |
| 1466 | + raise ValueError( |
| 1467 | + "Shape config is not required for non flex shape from user end." |
| 1468 | + ) |
1462 | 1469 |
|
1463 | 1470 | if dsc_job.job_infrastructure_configuration_details.get("subnetId"):
|
1464 | 1471 | dsc_job.job_infrastructure_configuration_details[
|
@@ -1495,7 +1502,10 @@ def init(self) -> DataScienceJob:
|
1495 | 1502 | self.build()
|
1496 | 1503 | .with_compartment_id(self.compartment_id or "{Provide a compartment OCID}")
|
1497 | 1504 | .with_project_id(self.project_id or "{Provide a project OCID}")
|
1498 |
| - .with_subnet_id(self.subnet_id or "{Provide a subnet OCID or remove this field if you use a default networking}") |
| 1505 | + .with_subnet_id( |
| 1506 | + self.subnet_id |
| 1507 | + or "{Provide a subnet OCID or remove this field if you use a default networking}" |
| 1508 | + ) |
1499 | 1509 | )
|
1500 | 1510 |
|
1501 | 1511 | def create(self, runtime, **kwargs) -> DataScienceJob:
|
@@ -1552,7 +1562,7 @@ def run(
|
1552 | 1562 | freeform_tags=None,
|
1553 | 1563 | defined_tags=None,
|
1554 | 1564 | wait=False,
|
1555 |
| - **kwargs |
| 1565 | + **kwargs, |
1556 | 1566 | ) -> DataScienceJobRun:
|
1557 | 1567 | """Runs a job on OCI Data Science job
|
1558 | 1568 |
|
@@ -1610,8 +1620,11 @@ def run(
|
1610 | 1620 | freeform_tags=freeform_tags,
|
1611 | 1621 | defined_tags=defined_tags,
|
1612 | 1622 | wait=wait,
|
1613 |
| - **kwargs |
| 1623 | + **kwargs, |
1614 | 1624 | )
|
| 1625 | + # A Runtime class may define customized run() method. |
| 1626 | + # Use the customized method if the run() method is defined by the runtime. |
| 1627 | + # Otherwise, use the default run() method defined in this class. |
1615 | 1628 | if hasattr(self.runtime, "run"):
|
1616 | 1629 | return self.runtime.run(self.dsc_job, **kwargs)
|
1617 | 1630 | return self.dsc_job.run(**kwargs)
|
|
0 commit comments