diff --git a/taipy/core/_repository/_filesystem_repository.py b/taipy/core/_repository/_filesystem_repository.py index 721c569b36..7cd6243dd5 100644 --- a/taipy/core/_repository/_filesystem_repository.py +++ b/taipy/core/_repository/_filesystem_repository.py @@ -191,10 +191,14 @@ def __filter_files_by_config_and_owner_id( return None def __match_file_and_get_entity(self, filepath, config_and_owner_ids, filters): - if match := [(c, p) for c, p in config_and_owner_ids if c.id in filepath.name]: + if match := [(c, p) for c, p in config_and_owner_ids if (c if isinstance(c, str) else c.id) in filepath.name]: for config, owner_id in match: for fil in filters: - fil.update({"config_id": config.id, "owner_id": owner_id}) + if isinstance(config, str): + config_id = config + else: + config_id = config.id + fil.update({"config_id": config_id, "owner_id": owner_id}) if data := self.__filter_by(filepath, filters): return config, owner_id, self.__file_content_to_entity(data) diff --git a/taipy/core/data/_data_manager.py b/taipy/core/data/_data_manager.py index e65d0ec779..6fdbfc077c 100644 --- a/taipy/core/data/_data_manager.py +++ b/taipy/core/data/_data_manager.py @@ -37,6 +37,17 @@ class _DataManager(_Manager[DataNode], _VersionMixin): _EVENT_ENTITY_TYPE = EventEntityType.DATA_NODE _repository: _DataFSRepository + @classmethod + def _get_owner_id( + cls, scope, cycle_id, scenario_id + ) -> Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]: + if scope == Scope.SCENARIO: + return scenario_id + elif scope == Scope.CYCLE: + return cycle_id + else: + return None + @classmethod def _bulk_get_or_create( cls, @@ -48,13 +59,7 @@ def _bulk_get_or_create( dn_configs_and_owner_id = [] for dn_config in data_node_configs: scope = dn_config.scope - owner_id: Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]] - if scope == Scope.SCENARIO: - owner_id = scenario_id - elif scope == Scope.CYCLE: - owner_id = cycle_id - else: - owner_id = None + owner_id = cls._get_owner_id(scope, cycle_id, scenario_id) dn_configs_and_owner_id.append((dn_config, owner_id)) data_nodes = cls._repository._get_by_configs_and_owner_ids( @@ -174,3 +179,25 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) for fil in filters: fil.update({"config_id": config_id}) return cls._repository._load_all(filters) + + @classmethod + def _clone( + cls, dn: DataNode, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None + ) -> DataNode: + data_nodes = cls._repository._get_by_configs_and_owner_ids( + [(dn.config_id, cls._get_owner_id(dn.scope, cycle_id, scenario_id))], cls._build_filters_with_version(None) + ) + + if existing_dn := data_nodes.get((dn.config_id, dn.owner_id)): + return existing_dn + else: + cloned_dn = cls._get(dn) + + cloned_dn.id = cloned_dn._new_id(cloned_dn._config_id) + cloned_dn._owner_id = cls._get_owner_id(cloned_dn._scope, cycle_id, scenario_id) + cloned_dn._parent_ids = set() + + cloned_dn._clone_data() + + cls._set(cloned_dn) + return cloned_dn diff --git a/taipy/core/data/_file_datanode_mixin.py b/taipy/core/data/_file_datanode_mixin.py index ff87146756..53c9e72429 100644 ---
b/taipy/core/data/_file_datanode_mixin.py @@ -42,6 +42,7 @@ class _FileDataNodeMixin: _PATH_KEY = "path" _DEFAULT_PATH_KEY = "default_path" _IS_GENERATED_KEY = "is_generated" + __TAIPY_CLONED_PREFIX = "TAIPY_CLONED" __logger = _TaipyLogger._get_logger() @@ -109,12 +110,14 @@ def _get_downloadable_path(self) -> str: return "" - def _upload(self, - path: str, - upload_checker: Optional[Callable[[str, Any], bool]] = None, - editor_id: Optional[str] = None, - comment: Optional[str] = None, - **kwargs: Any) -> ReasonCollection: + def _upload( + self, + path: str, + upload_checker: Optional[Callable[[str, Any], bool]] = None, + editor_id: Optional[str] = None, + comment: Optional[str] = None, + **kwargs: Any, + ) -> ReasonCollection: """Upload a file data to the data node. Arguments: @@ -136,11 +139,15 @@ def _upload(self, from ._data_manager_factory import _DataManagerFactory reasons = ReasonCollection() - if (editor_id - and self.edit_in_progress # type: ignore[attr-defined] - and self.editor_id != editor_id # type: ignore[attr-defined] - and (not self.editor_expiration_date # type: ignore[attr-defined] - or self.editor_expiration_date > datetime.now())): # type: ignore[attr-defined] + if ( + editor_id + and self.edit_in_progress # type: ignore[attr-defined] + and self.editor_id != editor_id # type: ignore[attr-defined] + and ( + not self.editor_expiration_date # type: ignore[attr-defined] + or self.editor_expiration_date > datetime.now() + ) + ): # type: ignore[attr-defined] reasons._add_reason(self.id, DataNodeEditInProgress(self.id)) # type: ignore[attr-defined] return reasons @@ -148,8 +155,7 @@ def _upload(self, try: upload_data = self._read_from_path(str(up_path)) except Exception as err: - self.__logger.error(f"Error uploading `{up_path.name}` to data " - f"node `{self.id}`:") # type: ignore[attr-defined] + self.__logger.error(f"Error uploading `{up_path.name}` to data " f"node `{self.id}`:") # type: ignore[attr-defined] self.__logger.error(f"Error: {err}") reasons._add_reason(self.id, UploadFileCanNotBeRead(up_path.name, self.id)) # type: ignore[attr-defined] return reasons @@ -161,7 +167,8 @@ def _upload(self, self.__logger.error( f"Error with the upload checker `{upload_checker.__name__}` " f"while checking `{up_path.name}` file for upload to the data " - f"node `{self.id}`:") # type: ignore[attr-defined] + f"node `{self.id}`:" + ) # type: ignore[attr-defined] self.__logger.error(f"Error: {err}") can_upload = False @@ -171,9 +178,12 @@ def _upload(self, shutil.copy(up_path, self.path) - self.track_edit(timestamp=datetime.now(), # type: ignore[attr-defined] - editor_id=editor_id, - comment=comment, **kwargs) + self.track_edit( + timestamp=datetime.now(), # type: ignore[attr-defined] + editor_id=editor_id, + comment=comment, + **kwargs, + ) self.unlock_edit() # type: ignore[attr-defined] _DataManagerFactory._build_manager()._set(self) # type: ignore[arg-type] @@ -212,3 +222,14 @@ def _migrate_path(self, storage_type, old_path) -> str: if os.path.exists(old_path): shutil.move(old_path, new_path) return new_path + + def _clone_data_file(self, id: str) -> Optional[str]: + if os.path.exists(self.path): + folder_path, base_name = os.path.split(self.path) + new_base_path = os.path.join(folder_path, f"TAIPY_CLONE_{id}_{base_name}") + if os.path.isdir(self.path): + shutil.copytree(self.path, new_base_path) + else: + shutil.copy(self.path, new_base_path) + return new_base_path + return "" diff --git a/taipy/core/data/csv.py b/taipy/core/data/csv.py index 083215bc4e..63ed805b06 100644 --- 
a/taipy/core/data/csv.py +++ b/taipy/core/data/csv.py @@ -192,3 +192,9 @@ def _write(self, data: Any, columns: Optional[List[str]] = None): encoding=properties[self.__ENCODING_KEY], header=properties[self._HAS_HEADER_PROPERTY], ) + + def _clone_data(self): + new_data_path = self._clone_data_file(self.id) + del self._properties._entity_owner + self._properties[self._PATH_KEY] = new_data_path + return new_data_path diff --git a/taipy/core/data/data_node.py b/taipy/core/data/data_node.py index 08e8b2e1da..75dbed6691 100644 --- a/taipy/core/data/data_node.py +++ b/taipy/core/data/data_node.py @@ -433,22 +433,27 @@ def append(self, data, editor_id: Optional[str] = None, comment: Optional[str] = corresponding to this write. """ from ._data_manager_factory import _DataManagerFactory - if (editor_id + + if ( + editor_id and self.edit_in_progress and self.editor_id != editor_id - and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now())): + and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now()) + ): raise DataNodeIsBeingEdited(self.id, self.editor_id) self._append(data) self.track_edit(editor_id=editor_id, comment=comment, **kwargs) self.unlock_edit() _DataManagerFactory._build_manager()._set(self) - def write(self, - data, - job_id: Optional[JobId] = None, - editor_id: Optional[str] = None, - comment: Optional[str] = None, - **kwargs: Any): + def write( + self, + data, + job_id: Optional[JobId] = None, + editor_id: Optional[str] = None, + comment: Optional[str] = None, + **kwargs: Any, + ): """Write some data to this data node. once the data is written, the data node is unlocked and the edit is tracked. @@ -461,10 +466,12 @@ def write(self, **kwargs (Any): Extra information to attach to the edit document corresponding to this write. """ - if (editor_id + if ( + editor_id and self.edit_in_progress and self.editor_id != editor_id - and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now())): + and (not self.editor_expiration_date or self.editor_expiration_date > datetime.now()) + ): raise DataNodeIsBeingEdited(self.id, self.editor_id) self._write(data) self.track_edit(job_id=job_id, editor_id=editor_id, comment=comment, **kwargs) @@ -473,12 +480,14 @@ def write(self, _DataManagerFactory._build_manager()._set(self) - def track_edit(self, - job_id: Optional[str] = None, - editor_id: Optional[str] = None, - timestamp: Optional[datetime] = None, - comment: Optional[str] = None, - **options: Any): + def track_edit( + self, + job_id: Optional[str] = None, + editor_id: Optional[str] = None, + timestamp: Optional[datetime] = None, + comment: Optional[str] = None, + **options: Any, + ): """Creates and adds a new entry in the edits attribute without writing the data. Arguments: @@ -627,15 +636,15 @@ def _get_rank(self, scenario_config_id: str) -> int: If the data node config is not part of the scenario config, 0xfffc is returned as an infinite rank. 
""" if not scenario_config_id: - return 0xfffb + return 0xFFFB dn_config = Config.data_nodes.get(self._config_id, None) if not dn_config: self._logger.error(f"Data node config `{self.config_id}` for data node `{self.id}` is not found.") - return 0xfffd + return 0xFFFD if not dn_config._ranks: self._logger.error(f"Data node config `{self.config_id}` for data node `{self.id}` has no rank.") - return 0xfffe - return dn_config._ranks.get(scenario_config_id, 0xfffc) + return 0xFFFE + return dn_config._ranks.get(scenario_config_id, 0xFFFC) @abstractmethod def _read(self): @@ -676,6 +685,9 @@ def _get_last_modified_datetime(cls, path: Optional[str] = None) -> Optional[dat return last_modified_datetime + def _clone_data(self): + raise NotImplementedError + @staticmethod def _class_map(): def all_subclasses(cls): diff --git a/taipy/core/data/excel.py b/taipy/core/data/excel.py index 3e39c1160f..d221cc23e5 100644 --- a/taipy/core/data/excel.py +++ b/taipy/core/data/excel.py @@ -339,3 +339,8 @@ def _write(self, data: Any): self._write_excel_with_single_sheet( data.to_excel, self._path, index=False, header=properties[self._HAS_HEADER_PROPERTY] or None ) + + def _clone_data(self): + new_data_path = self._clone_data_file(self.id) + self._properties[self._PATH_KEY] = new_data_path + return new_data_path diff --git a/taipy/core/data/json.py b/taipy/core/data/json.py index c18ab8d7b1..2479a0fe6c 100644 --- a/taipy/core/data/json.py +++ b/taipy/core/data/json.py @@ -158,6 +158,11 @@ def _write(self, data: Any): with open(self._path, "w", encoding=self.properties[self.__ENCODING_KEY]) as f: # type: ignore json.dump(data, f, indent=4, cls=self._encoder) + def _clone_data(self): + new_data_path = self._clone_data_file(self.id) + self._properties[self._PATH_KEY] = new_data_path + return new_data_path + class _DefaultJSONEncoder(json.JSONEncoder): def default(self, o): diff --git a/taipy/core/data/parquet.py b/taipy/core/data/parquet.py index 7c526b35d8..f675f0bcf7 100644 --- a/taipy/core/data/parquet.py +++ b/taipy/core/data/parquet.py @@ -249,3 +249,8 @@ def _append(self, data: Any): def _write(self, data: Any): self._write_with_kwargs(data) + + def _clone_data(self): + new_data_path = self._clone_data_file(self.id) + self._properties[self._PATH_KEY] = new_data_path + return new_data_path diff --git a/taipy/core/data/pickle.py b/taipy/core/data/pickle.py index b86e82d6c7..683456fd89 100644 --- a/taipy/core/data/pickle.py +++ b/taipy/core/data/pickle.py @@ -108,3 +108,9 @@ def _read_from_path(self, path: Optional[str] = None, **read_kwargs) -> Any: def _write(self, data): with open(self._path, "wb") as pf: pickle.dump(data, pf) + + def _clone_data(self): + new_data_path = self._clone_data_file(self.id) + del self._properties._entity_owner + self._properties[self._PATH_KEY] = new_data_path + return new_data_path diff --git a/taipy/core/scenario/_scenario_manager.py b/taipy/core/scenario/_scenario_manager.py index b8d71c25d9..c154afe45e 100644 --- a/taipy/core/scenario/_scenario_manager.py +++ b/taipy/core/scenario/_scenario_manager.py @@ -521,3 +521,56 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) for fil in filters: fil.update({"config_id": config_id}) return cls._repository._load_all(filters) + + @classmethod + def _clone(cls, scenario: Scenario, creation_date: Optional[datetime] = None) -> Scenario: + """ + Clone a scenario. + + Arguments: + scenario (Scenario): The scenario to clone. + + Returns: + Scenario: The cloned scenario. 
+ """ + creation_date = creation_date or datetime.now() + cloned_scenario = cls._get(scenario) + cloned_scenario.id = cloned_scenario._new_id(cloned_scenario.config_id) + + frequency = cls.__get_config(scenario).frequency + cycle = _CycleManagerFactory._build_manager()._get_or_create(frequency, creation_date) if frequency else None + cycle_id = cycle.id if cycle else None + + # TODO: update sequences + + # Clone tasks and data nodes + _task_manager = _TaskManagerFactory._build_manager() + _data_manager = _DataManagerFactory._build_manager() + + cloned_tasks = set() + for task in cloned_scenario.tasks.values(): + cloned_tasks.add(_task_manager._clone(task, cycle_id, cloned_scenario.id)) + cloned_scenario._tasks = cloned_tasks + + cloned_additional_data_nodes = set() + for data_node in cloned_scenario.additional_data_nodes.values(): + cloned_additional_data_nodes.add(_data_manager._clone(data_node, None, cloned_scenario.id)) + cloned_scenario._additional_data_nodes = cloned_additional_data_nodes + + for task in cloned_tasks: + if cloned_scenario.id not in task._parent_ids: + task._parent_ids.update([cloned_scenario.id]) + _task_manager._set(task) + + for dn in cloned_additional_data_nodes: + if cloned_scenario.id not in dn._parent_ids: + dn._parent_ids.update([cloned_scenario.id]) + _data_manager._set(dn) + + cloned_scenario._cycle = cycle + cloned_scenario._creation_date = creation_date + cloned_scenario._primary_scenario = len(cls._get_all_by_cycle(cycle)) == 0 if cycle else False + + cls._set(cloned_scenario) + + return cloned_scenario diff --git a/taipy/core/task/_task_manager.py b/taipy/core/task/_task_manager.py index 4336b8069f..fb06409ec8 100644 --- a/taipy/core/task/_task_manager.py +++ b/taipy/core/task/_task_manager.py @@ -57,6 +57,17 @@ def _set(cls, task: Task) -> None: cls.__save_data_nodes(task.output.values()) super()._set(task) + @classmethod + def _get_owner_id( + cls, scope, cycle_id, scenario_id + ) -> Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]]: + if scope == Scope.SCENARIO: + return scenario_id + elif scope == Scope.CYCLE: + return cycle_id + else: + return None + @classmethod def _bulk_get_or_create( cls, @@ -79,13 +90,7 @@ def _bulk_get_or_create( ] task_config_data_nodes = [data_nodes[dn_config] for dn_config in task_dn_configs] scope = min(dn.scope for dn in task_config_data_nodes) if len(task_config_data_nodes) != 0 else Scope.GLOBAL - owner_id: Union[Optional[SequenceId], Optional[ScenarioId], Optional[CycleId]] - if scope == Scope.SCENARIO: - owner_id = scenario_id - elif scope == Scope.CYCLE: - owner_id = cycle_id - else: - owner_id = None + owner_id = cls._get_owner_id(scope, cycle_id, scenario_id) tasks_configs_and_owner_id.append((task_config, owner_id)) @@ -226,3 +231,36 @@ def _get_by_config_id(cls, config_id: str, version_number: Optional[str] = None) for fil in filters: fil.update({"config_id": config_id}) return cls._repository._load_all(filters) + + @classmethod + def _clone(cls, task: Task, cycle_id: Optional[CycleId] = None, scenario_id: Optional[ScenarioId] = None) -> Task: + data_manager = _DataManagerFactory._build_manager() + + cloned_task = cls._get(task) + + inputs = [data_manager._clone(i, cycle_id, scenario_id) for i in cloned_task.input.values()] + outputs = [data_manager._clone(o, cycle_id, scenario_id) for o in cloned_task.output.values()] + + scope = min(dn.scope for dn in (inputs + outputs)) if (len(inputs) + len(outputs)) != 0 else Scope.GLOBAL + owner_id = cls._get_owner_id(scope, cycle_id, scenario_id) + + 
tasks_by_config = cls._repository._get_by_configs_and_owner_ids( # type: ignore + [(task.config_id, owner_id)], cls._build_filters_with_version(None) + ) + + if existing_task := tasks_by_config.get((task.config_id, owner_id)): + return existing_task + + cloned_task.id = cloned_task._new_id(cloned_task.config_id) + cloned_task._parent_ids = set() + cloned_task._owner_id = owner_id + + cloned_task._input = {i.config_id: i for i in inputs} + cloned_task._output = {o.config_id: o for o in outputs} + + for dn in set(inputs + outputs): + dn._parent_ids.update([cloned_task.id]) + data_manager._set(dn) + + cls._set(cloned_task) + return cloned_task diff --git a/taipy/core/task/task.py b/taipy/core/task/task.py index ecedf8ae4b..0bb04742f9 100644 --- a/taipy/core/task/task.py +++ b/taipy/core/task/task.py @@ -116,7 +116,7 @@ def __init__( skippable: bool = False, ) -> None: self._config_id = _validate_id(config_id) - self.id = id or TaskId(self.__ID_SEPARATOR.join([self._ID_PREFIX, self.config_id, str(uuid.uuid4())])) + self.id = id or self._new_id(config_id) self._owner_id = owner_id self._parent_ids = parent_ids or set() self._input = {dn.config_id: dn for dn in input or []} @@ -127,6 +127,11 @@ def __init__( self._properties = _Properties(self, **properties) self._init_done = True + @staticmethod + def _new_id(config_id: str) -> TaskId: + """Generate a unique task identifier.""" + return TaskId(Task.__ID_SEPARATOR.join([Task._ID_PREFIX, config_id, str(uuid.uuid4())])) + def __hash__(self) -> int: return hash(self.id) diff --git a/tests/core/data/test_csv_data_node.py b/tests/core/data/test_csv_data_node.py index dcd5f56cc1..2af952abb8 100644 --- a/tests/core/data/test_csv_data_node.py +++ b/tests/core/data/test_csv_data_node.py @@ -10,6 +10,7 @@ # specific language governing permissions and limitations under the License. 
import dataclasses +import filecmp import os import pathlib import re @@ -429,3 +430,13 @@ def check_data_is_positive(upload_path, upload_data): # The upload should succeed when check_data_is_positive() return True assert dn._upload(new_csv_path, upload_checker=check_data_is_positive) + + def test_clone_data_file(self): + path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv") + dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": path, "exposed_type": "pandas"}) + read_data = dn.read() + assert read_data is not None + + new_file_path = str(dn._clone_data()) + assert filecmp.cmp(path, new_file_path) + os.unlink(new_file_path) diff --git a/tests/core/data/test_data_manager.py b/tests/core/data/test_data_manager.py index 7316f6d498..e93c1c9166 100644 --- a/tests/core/data/test_data_manager.py +++ b/tests/core/data/test_data_manager.py @@ -731,3 +731,21 @@ def test_get_data_nodes_by_config_id_in_multiple_versions_environment(self): assert len(_DataManager._get_by_config_id(dn_config_1.id)) == 3 assert len(_DataManager._get_by_config_id(dn_config_2.id)) == 2 + + def test_clone_data_node(self): + csv_path_inp = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv") + dn_config = Config.configure_csv_data_node("dn_csv_in_1", default_path=csv_path_inp) + dn = _DataManager._create_and_set(dn_config, None, None) + + old_dn_id = dn.id + + assert len(_DataManager._get_all()) == 1 + + new_dn = _DataManager._clone(dn) + old_dn = _DataManager._get(old_dn_id) + + assert old_dn.id != new_dn.id + assert len(_DataManager._get_all()) == 2 + assert old_dn.properties["path"] != new_dn.properties["path"] + assert os.path.exists(str(new_dn.properties["path"])) + os.remove(str(new_dn.properties["path"])) diff --git a/tests/core/data/test_excel_data_node.py b/tests/core/data/test_excel_data_node.py index 0a262a8e90..6234dcc045 100644 --- a/tests/core/data/test_excel_data_node.py +++ b/tests/core/data/test_excel_data_node.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import filecmp import os import pathlib import re @@ -652,3 +653,13 @@ def check_data_is_positive(upload_path, upload_data): # The upload should succeed when check_data_is_positive() return True assert dn._upload(new_excel_path, upload_checker=check_data_is_positive) + + def test_clone_data_file(self): + path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.xlsx") + dn = ExcelDataNode("foo", Scope.SCENARIO, properties={"default_path": path}) + read_data = dn.read() + assert read_data is not None + + new_file_path = str(dn._clone_data()) + assert filecmp.cmp(path, new_file_path) + os.unlink(new_file_path) diff --git a/tests/core/data/test_json_data_node.py b/tests/core/data/test_json_data_node.py index 05b2b76b02..9f50dd79a3 100644 --- a/tests/core/data/test_json_data_node.py +++ b/tests/core/data/test_json_data_node.py @@ -10,6 +10,7 @@ # specific language governing permissions and limitations under the License. 
import datetime +import filecmp import json import os import pathlib @@ -492,3 +493,13 @@ def check_data_keys(upload_path, upload_data): # The upload should succeed when check_data_keys() return True assert dn._upload(json_file, upload_checker=check_data_keys) + + def test_clone_data_file(self): + path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/json/example_dict.json") + dn = JSONDataNode("foo", Scope.SCENARIO, properties={"path": path}) + read_data = dn.read() + assert read_data is not None + + new_file_path = str(dn._clone_data()) + assert filecmp.cmp(path, new_file_path) + os.unlink(new_file_path) diff --git a/tests/core/data/test_parquet_data_node.py b/tests/core/data/test_parquet_data_node.py index 1fc224dfa1..e829207499 100644 --- a/tests/core/data/test_parquet_data_node.py +++ b/tests/core/data/test_parquet_data_node.py @@ -9,9 +9,11 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +import filecmp import os import pathlib import re +import shutil import uuid from datetime import datetime, timedelta from importlib import util @@ -402,3 +404,13 @@ def check_data_is_positive(upload_path, upload_data): # The upload should succeed when check_data_is_positive() return True assert dn._upload(new_parquet_path, upload_checker=check_data_is_positive) + + def test_clone_data_file(self): + path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/parquet_example") + dn = ParquetDataNode("foo", Scope.SCENARIO, properties={"path": path}) + read_data = dn.read() + assert read_data is not None + + new_file_path = str(dn._clone_data()) + assert filecmp.dircmp(path, new_file_path) + shutil.rmtree(new_file_path) diff --git a/tests/core/data/test_pickle_data_node.py b/tests/core/data/test_pickle_data_node.py index 05deccf0cf..d8817beb92 100644 --- a/tests/core/data/test_pickle_data_node.py +++ b/tests/core/data/test_pickle_data_node.py @@ -9,6 +9,7 @@ # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. 
+import filecmp import os import pathlib import pickle @@ -305,3 +306,13 @@ def check_data_column(upload_path, upload_data): # The upload should succeed when check_data_column() return True assert dn._upload(pickle_file_path, upload_checker=check_data_column) + + def test_clone_data_file(self): + path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.p") + dn = PickleDataNode("foo", Scope.SCENARIO, properties={"default_path": path}) + read_data = dn.read() + assert read_data is not None + + new_file_path = str(dn._clone_data()) + assert filecmp.cmp(path, new_file_path) + os.unlink(new_file_path) diff --git a/tests/core/scenario/test_scenario_manager.py b/tests/core/scenario/test_scenario_manager.py index 6441a32e31..6ff2e52a07 100644 --- a/tests/core/scenario/test_scenario_manager.py +++ b/tests/core/scenario/test_scenario_manager.py @@ -1553,3 +1553,93 @@ def test_filter_scenarios_by_creation_datetime(): ) assert len(filtered_scenarios) == 1 assert [s_1_1] == filtered_scenarios + + +def test_clone_scenario(): + dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO) + dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.SCENARIO) + additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO) + task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2]) + scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1]) + scenario = _ScenarioManager._create(scenario_config_1) + + assert len(_ScenarioManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 3 + assert len(_TaskManager._get_all()) == 1 + + new_scenario = _ScenarioManager._clone(scenario) + + assert scenario.id != new_scenario.id + assert len(_ScenarioManager._get_all()) == 2 + assert len(_DataManager._get_all()) == 6 + assert len(_TaskManager._get_all()) == 2 + + assert all(scenario.id in t.parent_ids for t in scenario.tasks.values()) + assert all(scenario.id == t.owner_id for t in scenario.tasks.values()) + assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values()) + assert all(scenario.id == dn.owner_id for dn in scenario.data_nodes.values()) + + assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values()) + assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values()) + assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values()) + assert all(new_scenario.id == dn.owner_id for dn in new_scenario.data_nodes.values()) + + +def test_clone_scenario_with_single_GLOBAL_dn_scope(): + dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.SCENARIO) + dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.GLOBAL) + additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO) + task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2]) + scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1]) + scenario = _ScenarioManager._create(scenario_config_1) + + assert len(_ScenarioManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 3 + assert len(_TaskManager._get_all()) == 1 + + new_scenario = _ScenarioManager._clone(scenario) + + assert scenario.id != new_scenario.id + assert len(_ScenarioManager._get_all()) == 2 + assert len(_DataManager._get_all()) == 5 + assert len(_TaskManager._get_all()) == 2 + + assert all(scenario.id in 
t.parent_ids for t in scenario.tasks.values()) + assert all(scenario.id == t.owner_id for t in scenario.tasks.values()) + assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values()) + assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values()) + + assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values()) + assert all(new_scenario.id == t.owner_id for t in new_scenario.tasks.values()) + assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values()) + assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values()) + + +def test_clone_scenario_with_all_GLOBAL_dn_scope(): + dn_config_1 = Config.configure_pickle_data_node("dn_1", scope=Scope.GLOBAL) + dn_config_2 = Config.configure_pickle_data_node("dn_2", scope=Scope.GLOBAL) + additional_dn_config_1 = Config.configure_data_node("additional_dn_1", scope=Scope.SCENARIO) + task_config_1 = Config.configure_task("task_1", print, [dn_config_1], [dn_config_2]) + scenario_config_1 = Config.configure_scenario("scenario_1", [task_config_1], [additional_dn_config_1]) + scenario = _ScenarioManager._create(scenario_config_1) + + assert len(_ScenarioManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 3 + assert len(_TaskManager._get_all()) == 1 + + new_scenario = _ScenarioManager._clone(scenario) + + assert scenario.id != new_scenario.id + assert len(_ScenarioManager._get_all()) == 2 + assert len(_DataManager._get_all()) == 4 + assert len(_TaskManager._get_all()) == 1 + + assert all(scenario.id in t.parent_ids for t in scenario.tasks.values()) + assert all(t.owner_id is None for t in scenario.tasks.values()) + assert all(scenario.id in dn.parent_ids for dn in scenario.additional_data_nodes.values()) + assert all((scenario.id == dn.owner_id or dn.owner_id is None) for dn in scenario.data_nodes.values()) + + assert all(new_scenario.id in t.parent_ids for t in new_scenario.tasks.values()) + assert all(t.owner_id is None for t in new_scenario.tasks.values()) + assert all(new_scenario.id in dn.parent_ids for dn in new_scenario.additional_data_nodes.values()) + assert all((new_scenario.id == dn.owner_id or dn.owner_id is None) for dn in new_scenario.data_nodes.values()) diff --git a/tests/core/task/test_task_manager.py b/tests/core/task/test_task_manager.py index 55d98bd875..740f1c95e1 100644 --- a/tests/core/task/test_task_manager.py +++ b/tests/core/task/test_task_manager.py @@ -483,3 +483,27 @@ def test_get_scenarios_by_config_id_in_multiple_versions_environment(): def _create_task_from_config(task_config, *args, **kwargs): return _TaskManager._bulk_get_or_create([task_config], *args, **kwargs)[0] + + +def test_clone_task(): + dn_input_config_1 = Config.configure_pickle_data_node("my_input_1", scope=Scope.SCENARIO, default_data="testing") + dn_output_config_1 = Config.configure_pickle_data_node("my_output_1") + task_config_1 = Config.configure_task("task_config_1", print, dn_input_config_1, dn_output_config_1) + task = _create_task_from_config(task_config_1) + + task_id = task.id + + assert len(_TaskManager._get_all()) == 1 + assert len(_DataManager._get_all()) == 2 + + new_task = _TaskManager._clone(task) + + assert task.id != new_task.id + assert len(_TaskManager._get_all()) == 2 + assert len(_DataManager._get_all()) == 4 + + assert all(task_id in dn.parent_ids for dn in task.data_nodes.values()) + assert all(dn.owner_id is None for dn in task.data_nodes.values()) + + 
assert all(new_task.id in dn.parent_ids for dn in new_task.data_nodes.values()) + assert all(dn.owner_id is None for dn in new_task.data_nodes.values())
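For illustration only (not part of the patch): a minimal sketch of how the new cloning entry points might be exercised, mirroring the added tests above. The config names and default data are made up, and the private managers are used directly as the tests do; outcomes assume both data nodes are SCENARIO-scoped, as configured here.

from taipy import Config, Scope
from taipy.core.data._data_manager import _DataManager
from taipy.core.scenario._scenario_manager import _ScenarioManager

# Hypothetical configuration: one task reading a SCENARIO-scoped input
# and writing a SCENARIO-scoped output.
input_cfg = Config.configure_pickle_data_node("my_input", scope=Scope.SCENARIO, default_data="hello")
output_cfg = Config.configure_pickle_data_node("my_output", scope=Scope.SCENARIO)
task_cfg = Config.configure_task("my_task", print, [input_cfg], [output_cfg])
scenario_cfg = Config.configure_scenario("my_scenario", [task_cfg])

scenario = _ScenarioManager._create(scenario_cfg)
clone = _ScenarioManager._clone(scenario)

assert clone.id != scenario.id                 # the clone gets a freshly generated id
assert len(_ScenarioManager._get_all()) == 2   # both scenarios are persisted
assert len(_DataManager._get_all()) == 4       # the two SCENARIO-scoped data nodes are duplicated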