diff --git a/.babelrc b/.babelrc index 4fce77d4..29815062 100644 --- a/.babelrc +++ b/.babelrc @@ -14,4 +14,3 @@ "@babel/plugin-proposal-class-properties" ] } - diff --git a/.nycrc b/.nycrc index 6d6542d7..e1ffae4d 100644 --- a/.nycrc +++ b/.nycrc @@ -5,4 +5,3 @@ "src/**/*.jsx" ] } - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..aa6a6ab8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +repos: + - repo: https://github.com/Exabyte-io/pre-commit-hooks + rev: 2023.6.28 + hooks: + - id: ruff + exclude: ^tests/fixtures*|^dist* + - id: black + exclude: ^tests/fixtures*|^dist* + - id: isort + exclude: ^tests/fixtures*|^dist* + - id: mypy + exclude: ^tests/fixtures*|^dist* + - id: check-yaml + exclude: ^tests/fixtures*|^dist* + - id: end-of-file-fixer + exclude: ^tests/fixtures*|^dist* + - id: trailing-whitespace + exclude: ^tests/fixtures*|^dist* + - repo: local + hooks: + - id: lint-staged + name: lint-staged + language: node + entry: npx lint-staged + verbose: true + pass_filenames: false diff --git a/.prettierrc b/.prettierrc index a15f7f0e..dd5309ad 100644 --- a/.prettierrc +++ b/.prettierrc @@ -4,4 +4,3 @@ "trailingComma": "all", "tabWidth": 4 } - diff --git a/.yamllint.yml b/.yamllint.yml deleted file mode 100644 index 2f667ceb..00000000 --- a/.yamllint.yml +++ /dev/null @@ -1,26 +0,0 @@ ---- - -extends: default - -rules: - line-length: - max: 100 - empty-lines: - level: warning - max-end: 1 - quoted-strings: - quote-type: double - required: false - indentation: - spaces: 2 - indent-sequences: consistent - document-start: disable - comments: - min-spaces-from-content: 1 - comments-indentation: disable - -ignore: | - node_modules/ - test/ - tests/ - diff --git a/00_IMPLEMENTATION_CHECKLIST_MIN.md b/00_IMPLEMENTATION_CHECKLIST_MIN.md new file mode 100644 index 00000000..1e64d2f5 --- /dev/null +++ b/00_IMPLEMENTATION_CHECKLIST_MIN.md @@ -0,0 +1,65 @@ +# MVP Implementation Checklist - Minimal for Notebook + +## `src/py/mat3ra/wode/workflows/workflow.py` + +### Class: `Workflow` + +- `add_subworkflow(subworkflow, head=False, index=-1)` +- `add_relaxation()` +- `subworkflows` [property] +- `create(config)` [classmethod] - Create workflow from standata config +- `get_unit_by_name(name=None, name_regex=None)` - Search units across all subworkflows +- `set_unit(unit=None, unit_flowchart_id=None, new_unit=None)` - Replace unit in workflow + +## `src/py/mat3ra/wode/subworkflows/subworkflow.py` + +### Class: `Subworkflow` + +- `units` [property] +- `get_as_unit()` +- `get_unit_by_name(name=None, name_regex=None)` - Search units within subworkflow +- `model` [property with setter] - Ensure model can be set directly + +## `src/py/mat3ra/wode/units/unit.py` + +### Class: `Unit` (Base) + +- `flowchartId` [field] +- `name` [field] - inherited +- `context` [field: Dict] - Context data dict +- `add_context(new_context)` - Update context data +- `get_context(key, default)` - Get context value +- `remove_context(key)` - Remove context value +- `clear_context()` - Clear all context + +## `src/py/mat3ra/wode/units/io.py` + +### Class: `IOUnit` + +- `set_materials(materials)` - Set materials +- `add_feature(feature)` - Add feature +- `remove_feature(feature)` - Remove feature +- `add_target(target)` - Add target +- `remove_target(target)` - Remove target +- `has_feature(feature)` - Check feature exists +- `has_target(target)` - Check target exists + +## `src/py/mat3ra/wode/units/map.py` + +### Class: `MapUnit` + +- `set_workflow_id(id)` - Set workflow ID + +## `src/py/mat3ra/wode/units/processing.py` + +### Class: `ProcessingUnit` + +- `set_operation(op)` - Set operation +- `set_operation_type(type)` - Set operation type +- `set_input(input)` - Set input + +## `src/py/mat3ra/wode/units/factory.py` + +### Class: `UnitFactory` + +- `create(config)` [staticmethod] - Factory to instantiate correct unit type based on config["type"] diff --git a/pyproject.toml b/pyproject.toml index d46be7d3..26178c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,11 @@ classifiers = [ ] dependencies = [ "numpy", + "mat3ra-utils", + "mat3ra-esse @ git+https://github.com/Exabyte-io/esse.git@5fd47adb825854ac1fc4b0133cc5d0210e18a5bb", + "mat3ra-mode", + "mat3ra-ade @ git+https://github.com/Exabyte-io/ade.git@7b5bdbae13a1ef87ee64dec7e98ac7fd5661959b", + "mat3ra-standata" ] [project.optional-dependencies] @@ -62,13 +67,15 @@ extend-exclude = ''' tests\/fixtures*, examples\/.*\/.*\.py | other\/.*\/.*\.(py|ipynb) + | dist\/.* ) ''' [tool.ruff] extend-exclude = [ "src/js", - "tests/fixtures" + "tests/fixtures", + "dist" ] line-length = 120 target-version = "py310" @@ -80,6 +87,7 @@ target-version = "py310" profile = "black" multi_line_output = 3 include_trailing_comma = true +extend_skip_glob = ["dist/*"] [tool.pytest.ini_options] pythonpath = [ @@ -88,4 +96,3 @@ pythonpath = [ testpaths = [ "tests/py" ] - diff --git a/src/py/mat3ra/__init__.py b/src/py/mat3ra/__init__.py index 98cad5d6..8db66d3d 100644 --- a/src/py/mat3ra/__init__.py +++ b/src/py/mat3ra/__init__.py @@ -1,2 +1 @@ -"""mat3ra namespace package.""" - +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/src/py/mat3ra/wode/__init__.py b/src/py/mat3ra/wode/__init__.py index 5e1f4cb0..5c227ec2 100644 --- a/src/py/mat3ra/wode/__init__.py +++ b/src/py/mat3ra/wode/__init__.py @@ -1,6 +1,36 @@ -import numpy as np - - -def get_length(vec: np.ndarray) -> float: - return float(np.linalg.norm(vec)) +from .mixins import FlowchartUnitsManager +from .subworkflows import Subworkflow +from .units import ( + AssertionUnit, + AssignmentUnit, + ConditionUnit, + ExecutionUnit, + IOUnit, + MapUnit, + ProcessingUnit, + ReduceUnit, + SubworkflowUnit, + Unit, + UnitFactory, +) +from .utils import find_by_name_or_regex, generate_uuid +from .workflows import Workflow +__all__ = [ + "Unit", + "ExecutionUnit", + "AssignmentUnit", + "IOUnit", + "ConditionUnit", + "AssertionUnit", + "ProcessingUnit", + "MapUnit", + "ReduceUnit", + "SubworkflowUnit", + "Subworkflow", + "Workflow", + "UnitFactory", + "FlowchartUnitsManager", + "find_by_name_or_regex", + "generate_uuid", +] diff --git a/src/py/mat3ra/wode/context/__init__.py b/src/py/mat3ra/wode/context/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/context/providers/__init__.py b/src/py/mat3ra/wode/context/providers/__init__.py new file mode 100644 index 00000000..25c36b15 --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/__init__.py @@ -0,0 +1,3 @@ +from .points_grid_data_provider import PointsGridDataProvider + +__all__ = ["PointsGridDataProvider"] diff --git a/src/py/mat3ra/wode/context/providers/by_application/__init__.py b/src/py/mat3ra/wode/context/providers/by_application/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/context/providers/by_application/espresso/__init__.py b/src/py/mat3ra/wode/context/providers/by_application/espresso/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/context/providers/by_application/espresso/qe_pwx_context_provider.py b/src/py/mat3ra/wode/context/providers/by_application/espresso/qe_pwx_context_provider.py new file mode 100644 index 00000000..ecd96135 --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/by_application/espresso/qe_pwx_context_provider.py @@ -0,0 +1,87 @@ +# TODO: We need periodic_table.js equivalent in Python +# TODO: We need all mixins equivalent in Python + +from typing import Any, Dict, List + +from ..executable_context_provider import ExecutableContextProvider +from mat3ra.esse.models.context_providers_directory.by_application.qe_pwx_context_provider import ( + QEPwxContextProviderSchema, +) + + +class QEPWXContextProvider(QEPwxContextProviderSchema, ExecutableContextProvider): + """ + Context provider for Quantum ESPRESSO pw.x settings. + """ + + # self.init_materials_context_mixin() + # self.init_method_data_context_mixin() + # self.init_workflow_context_mixin() + # self.init_job_context_mixin() + # self.init_material_context_mixin() + _material: Any = None + _materials: List[Any] = [] + + @staticmethod + def atom_symbols(material: Any) -> List[str]: + raise NotImplementedError + + @staticmethod + def unique_elements_with_labels(material: Any) -> List[str]: + raise NotImplementedError + + @staticmethod + def atomic_positions_with_constraints(material: Any) -> str: + raise NotImplementedError + + @staticmethod + def atomic_positions(material: Any) -> str: + raise NotImplementedError + + @staticmethod + def nat(material: Any) -> int: + raise NotImplementedError + + @staticmethod + def ntyp(material: Any) -> int: + raise NotImplementedError + + @staticmethod + def ntyp_with_labels(material: Any) -> int: + raise NotImplementedError + + def build_qe_pwx_context(self, material: Any) -> Dict[str, Any]: + raise NotImplementedError + + def get_data_per_material(self) -> Dict[str, Any]: + raise NotImplementedError + + def get_data(self) -> Dict[str, Any]: + raise NotImplementedError + + @property + def restart_mode(self) -> str: + raise NotImplementedError + + def get_pseudo_by_symbol(self, symbol: str) -> Any: + raise NotImplementedError + + def atomic_species(self, material: Any) -> str: + raise NotImplementedError + + def atomic_species_with_labels(self, material: Any) -> str: + raise NotImplementedError + + @staticmethod + def cell_parameters(material: Any) -> str: + raise NotImplementedError + + @staticmethod + def symbol_to_atomic_specie(symbol: str, pseudo: Any) -> str: + raise NotImplementedError + + @staticmethod + def element_and_pseudo_to_atomic_specie_with_labels( + symbol: str, pseudo: Any, label: str = "" + ) -> str: + raise NotImplementedError diff --git a/src/py/mat3ra/wode/context/providers/by_application/executable_context_provider.py b/src/py/mat3ra/wode/context/providers/by_application/executable_context_provider.py new file mode 100644 index 00000000..b53f00d7 --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/by_application/executable_context_provider.py @@ -0,0 +1,8 @@ +from mat3ra.ade.context.context_provider import ContextProvider + + +class ExecutableContextProvider(ContextProvider): + """ + Context provider for executable settings. + """ + domain: str = "executable" diff --git a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py new file mode 100644 index 00000000..23188cfa --- /dev/null +++ b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py @@ -0,0 +1,55 @@ +from typing import Any, Dict, List + +from mat3ra.ade.context.context_provider import ContextProvider +from mat3ra.esse.models.context_providers_directory.enum import ContextProviderNameEnum +from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import GridMetricType, \ + PointsGridDataProviderSchema +from pydantic import Field + + +# TODO: GlobalSetting for default KPPRA value +class PointsGridDataProvider(PointsGridDataProviderSchema, ContextProvider): + """ + Context provider for k-point/q-point grid configuration. + + Handles grid dimensions and shifts for reciprocal space sampling. + """ + # TODO: Verify the correctness of the name + name: ContextProviderNameEnum = ContextProviderNameEnum.KGridFormDataManager + divisor: int = Field(default=1) + dimensions: List[int] = Field(default_factory=lambda: [1, 1, 1]) + shifts: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0]) + grid_metric_type: str = Field(default=GridMetricType.KPPRA) + + # TODO: handle presence of material + @property + def default_data(self) -> Dict[str, Any]: + return { + "dimensions": self.dimensions, + "shifts": self.shifts, + "gridMetricType": self.grid_metric_type, + "divisor": self.divisor, + } + + # TODO: add a test to verify context and templates are the same as from JS implementation + def get_default_grid_metric_value(self, metric: str) -> float: + raise NotImplementedError + + def calculate_dimensions( + self, + grid_metric_type: str, + grid_metric_value: float, + units: str = "angstrom" + ) -> List[int]: + raise NotImplementedError + + def calculate_grid_metric( + self, + grid_metric_type: str, + dimensions: List[int], + units: str = "angstrom" + ) -> float: + raise NotImplementedError + + def transform_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + raise NotImplementedError diff --git a/src/py/mat3ra/wode/mixins/__init__.py b/src/py/mat3ra/wode/mixins/__init__.py new file mode 100644 index 00000000..4174675c --- /dev/null +++ b/src/py/mat3ra/wode/mixins/__init__.py @@ -0,0 +1,4 @@ +from .flowchart_units_manager import FlowchartUnitsManager + +__all__ = ["FlowchartUnitsManager"] + diff --git a/src/py/mat3ra/wode/mixins/flowchart_units_manager.py b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py new file mode 100644 index 00000000..d648c0c8 --- /dev/null +++ b/src/py/mat3ra/wode/mixins/flowchart_units_manager.py @@ -0,0 +1,212 @@ +from typing import List, Optional, TypeVar + +from ..units import Unit +from ..utils import find_by_name_or_regex + +T = TypeVar("T") + + +class FlowchartUnitsManager: + """ + Mixin class providing common unit operations for flowchart units. + + This mixin expects the class to have a `units: List[Unit]` attribute. + It provides common methods for managing units in both Workflow and Subworkflow classes. + """ + + units: List[Unit] + + def set_units(self, units: List[Unit]) -> None: + self.units = units + + def get_unit(self, flowchart_id: str) -> Optional[Unit]: + for unit in self.units: + if unit.flowchartId == flowchart_id: + return unit + return None + + def find_unit_by_id(self, id: str) -> Optional[Unit]: + for unit in self.units: + if getattr(unit, 'id', None) == id: + return unit + return None + + def find_unit_with_tag(self, tag: str) -> Optional[Unit]: + for unit in self.units: + if hasattr(unit, 'tags') and unit.tags is not None and tag in unit.tags: + return unit + return None + + def get_unit_by_name( + self, + name: Optional[str] = None, + name_regex: Optional[str] = None, + ) -> Optional[Unit]: + return find_by_name_or_regex(self.units, name=name, name_regex=name_regex) + + @staticmethod + def _add_to_list(items: List[T], item: T, head: bool = False, index: int = -1) -> None: + """ + Add an item to a list at a specified position. + + Args: + items: The list to add to + item: The item to add + head: If True, insert at the beginning (index 0) + index: If >= 0, insert at this specific index + If < 0, append to the end + """ + if head: + items.insert(0, item) + elif index >= 0: + items.insert(index, item) + else: + items.append(item) + + def set_units_head(self, units: List[Unit]) -> List[Unit]: + """ + Set the head flag on the first unit and unset it on all others. + + Args: + units: List of units to process + + Returns: + The modified units list + """ + if len(units) > 0: + units[0].head = True + for unit in units[1:]: + unit.head = False + return units + + def set_next_links(self, units: List[Unit]) -> List[Unit]: + """ + Re-establishes the linked next => flowchartId logic in an array of units. + + Args: + units: List of units to process + + Returns: + The modified units list + """ + flowchart_ids = [unit.flowchartId for unit in units] + + for i in range(len(units) - 1): + unit_next = getattr(units[i], 'next', None) + + if unit_next is None: + units[i].next = units[i + 1].flowchartId + if i > 0: + units[i - 1].next = units[i].flowchartId + elif unit_next not in flowchart_ids: + units[i].next = units[i + 1].flowchartId + + return units + + def _clear_link_to_unit(self, flowchart_id: str) -> None: + """ + Clear the 'next' link from any unit that points to the given flowchart_id. + + This is used to mend broken links when removing a unit. + + Args: + flowchart_id: The flowchart_id to clear links to + """ + for unit in self.units: + if getattr(unit, 'next', None) == flowchart_id: + unit.next = None + break + + def add_unit(self, unit: Unit, head: bool = False, index: int = -1) -> None: + """ + Add a unit to the units list. + + Args: + unit: Unit to add + head: If True, add at the beginning + index: If >= 0, insert at this index + """ + if len(self.units) == 0: + unit.head = True + self.set_units([unit]) + else: + self._add_to_list(self.units, unit, head, index) + self.set_units(self.set_next_links(self.set_units_head(self.units))) + + # TODO: Consider removing setNextLinks and setUnitsHead calls when flowchart designer implemented. + def remove_unit(self, flowchart_id: str) -> None: + """ + Remove a unit by its flowchartId. + + Args: + flowchart_id: The flowchartId of the unit to remove + """ + + if len(self.units) < 2: + return + + unit_to_remove = None + for unit in self.units: + if unit.flowchartId == flowchart_id: + unit_to_remove = unit + break + + if not unit_to_remove: + return + + self._clear_link_to_unit(unit_to_remove.flowchartId) + + remaining_units = [unit for unit in self.units if unit.flowchartId != flowchart_id] + units_with_head = self.set_units_head(remaining_units) + self.units = self.set_next_links(units_with_head) + + def replace_unit(self, index: int, unit: Unit) -> None: + """ + Replace a unit at a specific index. + + Args: + index: Index of the unit to replace + unit: New unit to place at that index + """ + + if 0 <= index < len(self.units): + self.units[index] = unit + self.set_units(self.set_next_links(self.set_units_head(self.units))) + + def set_unit( + self, + new_unit: Unit, + unit: Optional[Unit] = None, + unit_flowchart_id: Optional[str] = None, + ) -> bool: + """ + Replace a unit by finding it either by instance or flowchart_id. + Unit can replace itself if it was modified outside the class. + + Args: + new_unit: The new unit to set + unit: The existing unit instance to replace (optional) + unit_flowchart_id: The flowchart_id of the unit to replace (optional) + + If neither unit nor unit_flowchart_id is provided, the function will use + new_unit.flowchartId to find the existing unit to replace. + + Returns: + True if successful, False otherwise + """ + if unit is not None: + target_unit = unit + elif unit_flowchart_id is not None: + target_unit = self.get_unit(unit_flowchart_id) + else: + target_unit = self.get_unit(new_unit.flowchartId) + + if target_unit is None: + return False + + try: + unit_index = self.units.index(target_unit) + self.replace_unit(unit_index, new_unit) + return True + except ValueError: + return False diff --git a/src/py/mat3ra/wode/subworkflows/__init__.py b/src/py/mat3ra/wode/subworkflows/__init__.py new file mode 100644 index 00000000..79b20eda --- /dev/null +++ b/src/py/mat3ra/wode/subworkflows/__init__.py @@ -0,0 +1,3 @@ +from .subworkflow import Subworkflow + +__all__ = ["Subworkflow"] diff --git a/src/py/mat3ra/wode/subworkflows/subworkflow.py b/src/py/mat3ra/wode/subworkflows/subworkflow.py new file mode 100644 index 00000000..a927a6ee --- /dev/null +++ b/src/py/mat3ra/wode/subworkflows/subworkflow.py @@ -0,0 +1,104 @@ +from typing import List, Optional + +from mat3ra.ade.application import Application +from mat3ra.code.entity import InMemoryEntitySnakeCase +from mat3ra.esse.models.workflow.subworkflow import Subworkflow as SubworkflowSchema +from mat3ra.mode.method import Method +from mat3ra.mode.model import Model +from pydantic import Field + +from ..mixins import FlowchartUnitsManager +from ..units import Unit +from ..utils import generate_uuid + + +class Subworkflow(SubworkflowSchema, InMemoryEntitySnakeCase, FlowchartUnitsManager): + """ + Subworkflow class representing a logical collection of workflow units. + + Attributes: + name: Name of the subworkflow + application: Application configuration + model: Model configuration + units: List of units in the subworkflow + properties: List of properties extracted by the subworkflow + """ + + field_id: str = Field(default_factory=generate_uuid, alias="_id") + application: Application = Field( + default_factory=lambda: Application(name="", version="", build="", shortName="", summary="") + ) + model: Model = Field(default_factory=lambda: Model(type="", subtype="", method=Method(type="", subtype=""))) + units: List[Unit] = Field(default_factory=list) + + @classmethod + def from_arguments( + cls, application: Application, model: Model, method: Method, name: str, units: Optional[List] = None, + config: Optional[dict] = None + ) -> "Subworkflow": + if units is None: + units = [] + if config is None: + config = {} + + model.method = method + return cls( + name=name, + application=application, + model=model, + units=units, + **config + ) + + @property + def id(self) -> str: + return self.field_id + + @property + def properties(self) -> List[str]: + raise NotImplementedError + + @property + def is_multimaterial(self) -> bool: + raise NotImplementedError + + @property + def method_data(self): + return self.model.method.data + + @property + def context_providers(self) -> list: + """ + Get unique subworkflow context providers from all units. + + Returns: + List of unique context providers that are marked as subworkflow providers + """ + raise NotImplementedError + + @property + def context_from_assignment_units(self) -> dict: + """ + Extract context from assignment units. + + Returns: + Dictionary mapping operand names to their values from assignment units + """ + raise NotImplementedError + + def get_as_unit(self) -> Unit: + return Unit( + type="subworkflow", + _id=self.id, + name=self.name + ) + + def render(self, context: Optional[dict] = None) -> None: + """ + Render the subworkflow and all its units with the given context. + + Args: + context: Context dictionary to pass to units, combined with application, + model, methodData, and subworkflow context + """ + raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/__init__.py b/src/py/mat3ra/wode/units/__init__.py new file mode 100644 index 00000000..dc37edb7 --- /dev/null +++ b/src/py/mat3ra/wode/units/__init__.py @@ -0,0 +1,25 @@ +from .assertion import AssertionUnit +from .assignment import AssignmentUnit +from .condition import ConditionUnit +from .execution import ExecutionUnit +from .factory import UnitFactory +from .io.base import IOUnit +from .map import MapUnit +from .processing import ProcessingUnit +from .reduce import ReduceUnit +from .subworkflow import SubworkflowUnit +from .unit import Unit + +__all__ = [ + "Unit", + "ExecutionUnit", + "AssignmentUnit", + "IOUnit", + "ConditionUnit", + "AssertionUnit", + "ProcessingUnit", + "MapUnit", + "ReduceUnit", + "SubworkflowUnit", + "UnitFactory", +] diff --git a/src/py/mat3ra/wode/units/assertion.py b/src/py/mat3ra/wode/units/assertion.py new file mode 100644 index 00000000..c3b709cb --- /dev/null +++ b/src/py/mat3ra/wode/units/assertion.py @@ -0,0 +1,7 @@ +from mat3ra.esse.models.workflow.unit.assertion import AssertionUnitSchema + +from .unit import Unit + + +class AssertionUnit(Unit, AssertionUnitSchema): + pass diff --git a/src/py/mat3ra/wode/units/assignment.py b/src/py/mat3ra/wode/units/assignment.py new file mode 100644 index 00000000..a06eff31 --- /dev/null +++ b/src/py/mat3ra/wode/units/assignment.py @@ -0,0 +1,10 @@ +from typing import List + +from mat3ra.esse.models.workflow.unit.assignment import AssignmentUnitSchema, WorkflowUnitInputSchema +from pydantic import Field + +from .unit import Unit + + +class AssignmentUnit(Unit, AssignmentUnitSchema): + input: List[WorkflowUnitInputSchema] = Field(default_factory=list) diff --git a/src/py/mat3ra/wode/units/condition.py b/src/py/mat3ra/wode/units/condition.py new file mode 100644 index 00000000..eb771c26 --- /dev/null +++ b/src/py/mat3ra/wode/units/condition.py @@ -0,0 +1,12 @@ +from mat3ra.esse.models.workflow.unit.condition import ConditionUnitSchema +from pydantic import Field + +from .unit import Unit + + +class ConditionUnit(Unit, ConditionUnitSchema): + statement: str = Field(default="") + then: str = Field(default="") + else_: str = Field(default="", alias="else") + input: list = Field(default_factory=list) + maxOccurrences: int = Field(default=1) diff --git a/src/py/mat3ra/wode/units/execution.py b/src/py/mat3ra/wode/units/execution.py new file mode 100644 index 00000000..1f16a15d --- /dev/null +++ b/src/py/mat3ra/wode/units/execution.py @@ -0,0 +1,14 @@ +from typing import List + +from mat3ra.ade import Executable, Flavor, Application +from mat3ra.esse.models.workflow.unit.execution import ExecutionUnitSchemaBase +from pydantic import Field + +from .unit import Unit + + +class ExecutionUnit(Unit, ExecutionUnitSchemaBase): + executable: Executable = None + flavor: Flavor = None + application: Application = None + input: List = Field(default_factory=list) diff --git a/src/py/mat3ra/wode/units/factory.py b/src/py/mat3ra/wode/units/factory.py new file mode 100644 index 00000000..59f8216f --- /dev/null +++ b/src/py/mat3ra/wode/units/factory.py @@ -0,0 +1,10 @@ +from typing import Dict + +from .unit import Unit + + +class UnitFactory: + # TODO: implement for MIN notebook + @staticmethod + def create(config: Dict) -> Unit: + raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/io/__init__.py b/src/py/mat3ra/wode/units/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/py/mat3ra/wode/units/io/base.py b/src/py/mat3ra/wode/units/io/base.py new file mode 100644 index 00000000..02caac68 --- /dev/null +++ b/src/py/mat3ra/wode/units/io/base.py @@ -0,0 +1,14 @@ +from typing import List + +from mat3ra.esse.models.workflow.unit.io import DataIOUnitSchema +from pydantic import Field + +from mat3ra.wode.units.unit import Unit + + +class IOUnit(Unit, DataIOUnitSchema): + source: str = Field(default="") + input: List = Field(default_factory=list) + + def set_materials(self, materials: List): + raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/io/data_frame_io.py b/src/py/mat3ra/wode/units/io/data_frame_io.py new file mode 100644 index 00000000..ff343010 --- /dev/null +++ b/src/py/mat3ra/wode/units/io/data_frame_io.py @@ -0,0 +1,7 @@ +from mat3ra.esse.models.workflow.unit.io import Subtype + +from .. import IOUnit + + +class DataFrameIOUnit(IOUnit): + subtype: str = Subtype.dataFrame diff --git a/src/py/mat3ra/wode/units/io/input_io.py b/src/py/mat3ra/wode/units/io/input_io.py new file mode 100644 index 00000000..bfe50ec6 --- /dev/null +++ b/src/py/mat3ra/wode/units/io/input_io.py @@ -0,0 +1,7 @@ +from mat3ra.esse.models.workflow.unit.io import Subtype + +from .. import IOUnit + + +class InputIOUnit(IOUnit): + subtype: str = Subtype.input diff --git a/src/py/mat3ra/wode/units/io/output_io.py b/src/py/mat3ra/wode/units/io/output_io.py new file mode 100644 index 00000000..bc23ac1b --- /dev/null +++ b/src/py/mat3ra/wode/units/io/output_io.py @@ -0,0 +1,7 @@ +from mat3ra.esse.models.workflow.unit.io import Subtype + +from .. import IOUnit + + +class OutputIOUnit(IOUnit): + subtype: str = Subtype.output diff --git a/src/py/mat3ra/wode/units/map.py b/src/py/mat3ra/wode/units/map.py new file mode 100644 index 00000000..d3b7ef9d --- /dev/null +++ b/src/py/mat3ra/wode/units/map.py @@ -0,0 +1,11 @@ +from mat3ra.esse.models.workflow.unit.map import MapUnitSchema +from pydantic import Field + +from .unit import Unit + + +class MapUnit(Unit, MapUnitSchema): + workflowId: str = Field(default="") + input: list = Field(default_factory=list) + def set_workflow_id(self, id: str): + raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/processing.py b/src/py/mat3ra/wode/units/processing.py new file mode 100644 index 00000000..fa94b56b --- /dev/null +++ b/src/py/mat3ra/wode/units/processing.py @@ -0,0 +1,20 @@ +from typing import Any + +from mat3ra.esse.models.workflow.unit.processing import ProcessingUnitSchema +from pydantic import Field + +from .unit import Unit + + +class ProcessingUnit(Unit, ProcessingUnitSchema): + operation: str = Field(default="") + operationType: str = Field(default="") + inputData: list = Field(default_factory=list) + def set_operation(self, op: Any): + raise NotImplementedError + + def set_operation_type(self, type: str): + raise NotImplementedError + + def set_input(self, input: Any): + raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/reduce.py b/src/py/mat3ra/wode/units/reduce.py new file mode 100644 index 00000000..b0eb835e --- /dev/null +++ b/src/py/mat3ra/wode/units/reduce.py @@ -0,0 +1,11 @@ +from typing import List + +from mat3ra.esse.models.workflow.unit.reduce import ReduceUnitSchema, InputItem +from pydantic import Field + +from .unit import Unit + + +class ReduceUnit(Unit, ReduceUnitSchema): + mapFlowchartId: str = Field(default="") + input: List[InputItem] = Field(default_factory=list) diff --git a/src/py/mat3ra/wode/units/subworkflow.py b/src/py/mat3ra/wode/units/subworkflow.py new file mode 100644 index 00000000..98ecb839 --- /dev/null +++ b/src/py/mat3ra/wode/units/subworkflow.py @@ -0,0 +1,7 @@ +from mat3ra.esse.models.workflow.unit.subworkflow import SubworkflowUnitSchema + +from .unit import Unit + + +class SubworkflowUnit(Unit, SubworkflowUnitSchema): + pass diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py new file mode 100644 index 00000000..d0109b32 --- /dev/null +++ b/src/py/mat3ra/wode/units/unit.py @@ -0,0 +1,48 @@ +from typing import Any, Dict, List + +from mat3ra.code.entity import InMemoryEntitySnakeCase +from mat3ra.esse.models.workflow.unit.base import WorkflowBaseUnitSchema +from pydantic import Field + +from ..utils import generate_uuid + + +class Unit(WorkflowBaseUnitSchema, InMemoryEntitySnakeCase): + """ + Unit class representing a unit of computational work in a workflow. + + Attributes: + type: Type of the unit (e.g., execution, assignment, condition) + name: Name of the unit + flowchartId: Unique identifier for the unit in the flowchart + head: Whether this unit is the head of the workflow + next: Flowchart ID of the next unit + tags: List of tags for the unit + context: Context data dictionary for the unit + """ + + flowchartId: str = Field(default_factory=generate_uuid) + # TODO: use RuntimeItemNameObjectSchema when available + preProcessors: List[Any] = Field(default_factory=list) + postProcessors: List[Any] = Field(default_factory=list) + monitors: List[Any] = Field(default_factory=list) + results: List[Any] = Field(default_factory=list) + context: Dict[str, Any] = Field(default_factory=dict) + + def is_in_status(self, status: str) -> bool: + return self.status == status + + def add_context(self, new_context: Dict[str, Any]): + self.context.update(new_context) + + def set_context(self, new_context: Dict[str, Any]): + self.context = new_context + + def get_context(self, key: str, default: Any = None) -> Any: + return self.context.get(key, default) + + def remove_context(self, key: str): + self.context.pop(key, None) + + def clear_context(self): + self.context = {} diff --git a/src/py/mat3ra/wode/utils.py b/src/py/mat3ra/wode/utils.py new file mode 100644 index 00000000..cb0eb894 --- /dev/null +++ b/src/py/mat3ra/wode/utils.py @@ -0,0 +1,37 @@ +import re +from typing import Any, List, Optional + +from mat3ra.utils.uuid import get_uuid + + +def generate_uuid() -> str: + return get_uuid() + + +def find_by_name_or_regex( + items: List[Any], + name: Optional[str] = None, + name_regex: Optional[str] = None, +) -> Optional[Any]: + """ + Find an item in a list by exact name match or regex pattern. + + Args: + items: List of objects to search through + name: Exact name to match + name_regex: Regex pattern to match against names + + Returns: + First matching item or None + """ + if name: + name_lower = name.lower() + for item in items: + if item.name.lower() == name_lower: + return item + elif name_regex: + pattern = re.compile(name_regex, re.IGNORECASE) + for item in items: + if pattern.search(item.name): + return item + return None diff --git a/src/py/mat3ra/wode/workflows/__init__.py b/src/py/mat3ra/wode/workflows/__init__.py new file mode 100644 index 00000000..4a184e02 --- /dev/null +++ b/src/py/mat3ra/wode/workflows/__init__.py @@ -0,0 +1,3 @@ +from .workflow import Workflow + +__all__ = ["Workflow"] diff --git a/src/py/mat3ra/wode/workflows/workflow.py b/src/py/mat3ra/wode/workflows/workflow.py new file mode 100644 index 00000000..4f0da428 --- /dev/null +++ b/src/py/mat3ra/wode/workflows/workflow.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, List, Optional + +from mat3ra.code.entity import InMemoryEntitySnakeCase +from mat3ra.esse.models.workflow import WorkflowSchema +from mat3ra.standata.subworkflows import SubworkflowStandata +from pydantic import Field + +from ..mixins import FlowchartUnitsManager +from ..subworkflows import Subworkflow +from ..units import Unit +from ..utils import generate_uuid + + +class Workflow(WorkflowSchema, InMemoryEntitySnakeCase, FlowchartUnitsManager): + """ + Workflow class representing a complete workflow configuration. + + Attributes: + name: Name of the workflow + subworkflows: List of subworkflows in the workflow + units: List of units linking the subworkflows + properties: List of properties extracted by the workflow + """ + + field_id: str = Field(default_factory=generate_uuid, alias="_id") + subworkflows: List[Subworkflow] = Field(default_factory=list) + units: List[Unit] = Field(default_factory=list) + isMultiMaterial: bool = Field(default=False) + + @property + def application(self): + if not self.subworkflows or len(self.subworkflows) == 0: + return None + + first_subworkflow = self.subworkflows[0] + return first_subworkflow.application if first_subworkflow.application else None + + @classmethod + def from_subworkflow(cls, subworkflow: Subworkflow) -> "Workflow": + raise NotImplementedError + + @classmethod + def from_subworkflows(cls, name: str, *subworkflows: Subworkflow) -> "Workflow": + raise NotImplementedError + + @property + def is_multimaterial(self) -> bool: + raise NotImplementedError + + @property + def all_subworkflows(self) -> List[Subworkflow]: + raise NotImplementedError + + # TODO: add computed_properties — "properties" will conflict with Pydantic fields + + @property + def relaxation_subworkflow(self) -> Optional[Subworkflow]: + application_name = self.application.name if self.application else None + subworkflow_standata = SubworkflowStandata() + relaxation_data = subworkflow_standata.get_relaxation_by_application(application_name) + return Subworkflow(**relaxation_data) if relaxation_data else None + + @property + def has_relaxation(self) -> bool: + return self._find_relaxation_subworkflow() is not None + + def add_subworkflow(self, subworkflow: Subworkflow, head: bool = False, index: int = -1): + self._add_to_list(self.subworkflows, subworkflow, head, index) + unit = subworkflow.get_as_unit() + self.add_unit(unit, head, index) + + def remove_subworkflow_by_id(self, id: str): + self.subworkflows = [sw for sw in self.subworkflows if sw.id != id] + + def replace_subworkflow_at_index(self, index: int, new_subworkflow: Subworkflow): + raise NotImplementedError + + def find_subworkflow_by_id(self, id: str) -> Optional[Subworkflow]: + raise NotImplementedError + + def set_context_to_unit(self, unit_name: Optional[str] = None, unit_name_regex: Optional[str] = None, + new_context: Optional[Dict[str, Any]] = None): + target_unit = self.get_unit_by_name(name=unit_name, name_regex=unit_name_regex) + target_unit.context = new_context + + def add_unit_type(self, unit_type: str, head: bool = False, index: int = -1): + raise NotImplementedError + + def _find_relaxation_subworkflow(self) -> Optional[Subworkflow]: + target_name = self.relaxation_subworkflow.name + + return next( + (swf for swf in self.subworkflows if swf.name == target_name), + None, + ) + + def add_relaxation(self) -> None: + if self.has_relaxation: + return + + relaxation_definition = self.relaxation_subworkflow + if relaxation_definition is not None: + self.add_subworkflow(relaxation_definition, head=True) + + def remove_relaxation(self) -> None: + existing = self._find_relaxation_subworkflow() + if existing is not None: + self.remove_subworkflow_by_id(existing.id) diff --git a/tests/py/context/test_points_grid_data_provider.py b/tests/py/context/test_points_grid_data_provider.py new file mode 100644 index 00000000..510b8cae --- /dev/null +++ b/tests/py/context/test_points_grid_data_provider.py @@ -0,0 +1,96 @@ +import pytest +from mat3ra.wode.context.providers import PointsGridDataProvider + +# Test data constants +DIMENSIONS_DEFAULT = [1, 1, 1] +DIMENSIONS_CUSTOM = [1, 2, 3] +SHIFTS_DEFAULT = [0.0, 0.0, 0.0] +SHIFTS_CUSTOM = [0.5, 0.5, 0.5] +DIVISOR_DEFAULT = 1 +DIVISOR_CUSTOM = 2 +DATA_CUSTOM = { + "dimensions": DIMENSIONS_CUSTOM, + "shifts": SHIFTS_CUSTOM, + "divisor": DIVISOR_CUSTOM, +} + + +@pytest.mark.parametrize( + "init_params,expected_dimensions,expected_shifts,expected_divisor", + [ + ( + {"dimensions": DIMENSIONS_CUSTOM}, + DIMENSIONS_CUSTOM, + SHIFTS_DEFAULT, + DIVISOR_DEFAULT, + ), + ], +) +def test_points_grid_data_provider_initialization( + init_params, expected_dimensions, expected_shifts, expected_divisor +): + kgrid_context_provider_relax = PointsGridDataProvider(**init_params) + + assert kgrid_context_provider_relax.dimensions == expected_dimensions + assert kgrid_context_provider_relax.shifts == expected_shifts + assert kgrid_context_provider_relax.divisor == expected_divisor + + +@pytest.mark.parametrize( + "init_params,expected_dimensions,expected_shifts,expected_divisor", + [ + ( + {"dimensions": DIMENSIONS_CUSTOM}, + DIMENSIONS_CUSTOM, + SHIFTS_DEFAULT, + DIVISOR_DEFAULT, + ), + ], +) +def test_points_grid_data_provider_get_data(init_params, expected_dimensions, expected_shifts, expected_divisor): + kgrid_context_provider_relax = PointsGridDataProvider(**init_params) + + new_context_relax = kgrid_context_provider_relax.get_data() + + assert isinstance(new_context_relax, dict) + assert "dimensions" in new_context_relax + assert "shifts" in new_context_relax + assert "divisor" in new_context_relax + assert "gridMetricType" in new_context_relax + + assert new_context_relax["dimensions"] == expected_dimensions + assert new_context_relax["shifts"] == expected_shifts + assert new_context_relax["divisor"] == expected_divisor + + +@pytest.mark.parametrize( + "init_params,expected_dimensions,expected_shifts,expected_divisor", + [ + ( + {"dimensions": DIMENSIONS_CUSTOM}, + DIMENSIONS_CUSTOM, + SHIFTS_DEFAULT, + DIVISOR_DEFAULT, + ), + ], +) +def test_points_grid_data_provider_yield_data(init_params, expected_dimensions, expected_shifts, expected_divisor): + kgrid_context_provider_relax = PointsGridDataProvider(**init_params) + + yielded_context = kgrid_context_provider_relax.yield_data() + + print(yielded_context) + assert isinstance(yielded_context, dict) + assert "KGridFormDataManager" in yielded_context + assert "isKGridFormDataManagerEdited" in yielded_context + + data = yielded_context["KGridFormDataManager"] + assert isinstance(data, dict) + assert "dimensions" in data + assert "shifts" in data + assert "divisor" in data + assert "gridMetricType" in data + + assert data["dimensions"] == expected_dimensions + assert data["shifts"] == expected_shifts + assert data["divisor"] == expected_divisor diff --git a/tests/py/test_flowchart_units_manager.py b/tests/py/test_flowchart_units_manager.py new file mode 100644 index 00000000..d5f403fa --- /dev/null +++ b/tests/py/test_flowchart_units_manager.py @@ -0,0 +1,205 @@ +import pytest + +from mat3ra.wode.units import Unit +from mat3ra.wode.workflows import Workflow + +UNIT_1_NAME = "unit_1" +UNIT_2_NAME = "unit_2" +UNIT_3_NAME = "unit_3" +UNIT_TAG = "test_tag" +FLOWCHART_ID_1 = "flowchart-id-1" +FLOWCHART_ID_2 = "flowchart-id-2" +FLOWCHART_ID_3 = "flowchart-id-3" + +UNIT_CONFIG_1 = { + "type": "execution", + "name": UNIT_1_NAME, + "flowchartId": FLOWCHART_ID_1, +} + +UNIT_CONFIG_2 = { + "type": "execution", + "name": UNIT_2_NAME, + "flowchartId": FLOWCHART_ID_2, +} + +UNIT_CONFIG_3 = { + "type": "execution", + "name": UNIT_3_NAME, + "flowchartId": FLOWCHART_ID_3, + "tags": [UNIT_TAG], +} + + +@pytest.fixture +def workflow(): + """Create a workflow instance for testing FlowchartUnitsManager methods.""" + return Workflow(name="Test Workflow") + + +@pytest.fixture +def unit_1(): + return Unit(**UNIT_CONFIG_1) + + +@pytest.fixture +def unit_2(): + return Unit(**UNIT_CONFIG_2) + + +@pytest.fixture +def unit_3(): + return Unit(**UNIT_CONFIG_3) + + +def test_set_units(workflow, unit_1, unit_2): + units = [unit_1, unit_2] + workflow.set_units(units) + assert len(workflow.units) == 2 + assert workflow.units[0].flowchartId == FLOWCHART_ID_1 + assert workflow.units[1].flowchartId == FLOWCHART_ID_2 + + +def test_get_unit(workflow, unit_1, unit_2): + workflow.set_units([unit_1, unit_2]) + found_unit = workflow.get_unit(FLOWCHART_ID_1) + assert found_unit is not None + assert found_unit.flowchartId == FLOWCHART_ID_1 + assert found_unit.name == UNIT_1_NAME + + +def test_find_unit_by_id(workflow, unit_1, unit_2): + workflow.set_units([unit_1, unit_2]) + found_unit = workflow.find_unit_by_id(unit_1.field_id) + assert found_unit is not None + assert found_unit.field_id == unit_1.field_id + + +def test_find_unit_with_tag(workflow, unit_3): + workflow.set_units([unit_3]) + found_unit = workflow.find_unit_with_tag(UNIT_TAG) + assert found_unit is not None + assert UNIT_TAG in found_unit.tags + + +@pytest.mark.parametrize("search_name,expected_name", [ + (UNIT_1_NAME, UNIT_1_NAME), + (UNIT_1_NAME.upper(), UNIT_1_NAME), # Case insensitive + (UNIT_2_NAME, UNIT_2_NAME), +]) +def test_get_unit_by_name(workflow, unit_1, unit_2, search_name, expected_name): + workflow.set_units([unit_1, unit_2]) + found_unit = workflow.get_unit_by_name(name=search_name) + assert found_unit is not None + assert found_unit.name == expected_name + + +def test_get_unit_by_name_regex(workflow, unit_1, unit_2): + workflow.set_units([unit_1, unit_2]) + found_unit = workflow.get_unit_by_name(name_regex=r"unit_\d") + assert found_unit is not None + assert found_unit.name in [UNIT_1_NAME, UNIT_2_NAME] + + +def test_set_units_head(workflow, unit_1, unit_2, unit_3): + units = [unit_1, unit_2, unit_3] + result = workflow.set_units_head(units) + assert result[0].head is True + assert result[1].head is False + assert result[2].head is False + + +def test_set_next_links(workflow, unit_1, unit_2, unit_3): + units = [unit_1, unit_2, unit_3] + result = workflow.set_next_links(units) + assert result[0].next == FLOWCHART_ID_2 + assert result[1].next == FLOWCHART_ID_3 + assert result[2].next is None or result[2].next == "" + + +def test_clear_link_to_unit(workflow, unit_1, unit_2): + unit_1.next = FLOWCHART_ID_2 + workflow.set_units([unit_1, unit_2]) + workflow._clear_link_to_unit(FLOWCHART_ID_2) + assert unit_1.next is None + + +def test_add_unit(workflow, unit_1, unit_2): + workflow.add_unit(unit_1) + assert len(workflow.units) == 1 + assert workflow.units[0].head is True + assert workflow.units[0].flowchartId == FLOWCHART_ID_1 + + workflow.add_unit(unit_2) + assert len(workflow.units) == 2 + assert workflow.units[0].head is True + assert workflow.units[1].head is False + assert workflow.units[0].next == FLOWCHART_ID_2 + + +@pytest.mark.parametrize("head,expected_order", [ + (True, [FLOWCHART_ID_3, FLOWCHART_ID_1, FLOWCHART_ID_2]), + (False, [FLOWCHART_ID_1, FLOWCHART_ID_2, FLOWCHART_ID_3]), +]) +def test_add_unit_head_parameter(workflow, unit_1, unit_2, unit_3, head, expected_order): + workflow.add_unit(unit_1) + workflow.add_unit(unit_2) + workflow.add_unit(unit_3, head=head) + + actual_order = [u.flowchartId for u in workflow.units] + assert actual_order == expected_order + + +def test_remove_unit(workflow, unit_1, unit_2, unit_3): + workflow.set_units([unit_1, unit_2, unit_3]) + workflow.remove_unit(FLOWCHART_ID_2) + + assert len(workflow.units) == 2 + assert workflow.units[0].flowchartId == FLOWCHART_ID_1 + assert workflow.units[1].flowchartId == FLOWCHART_ID_3 + assert workflow.units[0].next == FLOWCHART_ID_3 + + +def test_replace_unit(workflow, unit_1, unit_2): + workflow.set_units([unit_1]) + workflow.replace_unit(0, unit_2) + + assert len(workflow.units) == 1 + assert workflow.units[0].flowchartId == FLOWCHART_ID_2 + + +@pytest.mark.parametrize("provide_unit,provide_id,should_succeed", [ + (True, False, True), # Provide unit instance + (False, True, True), # Provide flowchart_id + (False, False, True), # Provide neither (use new_unit.flowchartId) +]) +def test_set_unit(workflow, unit_1, unit_2, provide_unit, provide_id, should_succeed): + workflow.set_units([unit_1]) + + # Create new unit with same flowchart_id + new_unit = Unit(type="execution", name="Updated Unit", flowchartId=FLOWCHART_ID_1) + + if provide_unit: + result = workflow.set_unit(new_unit, unit=unit_1) + elif provide_id: + result = workflow.set_unit(new_unit, unit_flowchart_id=FLOWCHART_ID_1) + else: + result = workflow.set_unit(new_unit) + + assert result is should_succeed + if should_succeed: + assert workflow.units[0].name == "Updated Unit" + + +def test_set_unit_replaces_itself(workflow): + unit = Unit(type="execution", name="Original", flowchartId=FLOWCHART_ID_1) + workflow.add_unit(unit) + + # Modify the unit outside + unit.name = "Modified" + + # Replace it with itself using flowchartId + result = workflow.set_unit(unit) + + assert result is True + assert workflow.units[0].name == "Modified" diff --git a/tests/py/test_sample.py b/tests/py/test_sample.py deleted file mode 100644 index 6f31cc17..00000000 --- a/tests/py/test_sample.py +++ /dev/null @@ -1,11 +0,0 @@ -import numpy as np -from mat3ra.wode import get_length - - -def test_get_length(): - """Test that get_length returns correct type and value.""" - vec = np.array([1, 2]) - result = get_length(vec) - assert isinstance(result, float) - assert np.isclose(result, np.sqrt(5)) - diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py new file mode 100644 index 00000000..ac653210 --- /dev/null +++ b/tests/py/test_subworkflow.py @@ -0,0 +1,74 @@ +import pytest +from mat3ra.ade.application import Application +from mat3ra.mode.method import Method +from mat3ra.mode.model import Model +from mat3ra.standata.applications import ApplicationStandata +from mat3ra.wode import Subworkflow, Unit + +SUBWORKFLOW_NAME = "Total Energy" +SUBWORKFLOW_APPLICATION = Application(**ApplicationStandata.get_by_name_first_match("espresso")) +SUBWORKFLOW_METHOD = Method(type="pseudopotential", subtype="us") +SUBWORKFLOW_MODEL = Model(type="dft", subtype="gga", method=SUBWORKFLOW_METHOD) +SUBWORKFLOW_PROPERTIES = ["total_energy", "pressure"] + +UNIT_CONFIG = { + "type": "execution", + "name": "pw_scf", + "flowchartId": "unit-flowchart-id", + "head": True, +} + + +def test_creation(): + sw = Subworkflow(name=SUBWORKFLOW_NAME) + assert sw.name == SUBWORKFLOW_NAME + + +@pytest.mark.parametrize("app_name", ["espresso", "vasp"]) +def test_application(app_name): + app_data = ApplicationStandata.get_by_name_first_match(app_name) + application = Application(**app_data) + sw = Subworkflow(name=SUBWORKFLOW_NAME, application=application) + assert sw.application.name == app_name + assert sw.application.version == app_data["version"] + + +@pytest.mark.parametrize( + "model_type,model_subtype", + [ + ("dft", "gga"), + ("dft", "lda"), + ], +) +def test_model(model_type, model_subtype): + method = Method(type="pseudopotential", subtype="us") + model = Model(type=model_type, subtype=model_subtype, method=method) + sw = Subworkflow(name=SUBWORKFLOW_NAME, model=model) + assert sw.model.type == model_type + assert sw.model.subtype == model_subtype + + +def test_properties(): + sw = Subworkflow(name=SUBWORKFLOW_NAME, properties=SUBWORKFLOW_PROPERTIES) + assert sw.properties == SUBWORKFLOW_PROPERTIES + + +def test_with_units(): + unit = Unit(**UNIT_CONFIG) + sw = Subworkflow(name=SUBWORKFLOW_NAME, units=[unit]) + assert len(sw.units) == 1 + assert sw.units[0].name == UNIT_CONFIG["name"] + + +def test_field_id_generation(): + sw1 = Subworkflow(name=SUBWORKFLOW_NAME) + sw2 = Subworkflow(name=SUBWORKFLOW_NAME) + assert sw1.field_id != sw2.field_id + + +@pytest.mark.skip(reason="Implementation not complete") +def test_to_dict(): + sw = Subworkflow(name=SUBWORKFLOW_NAME, application=SUBWORKFLOW_APPLICATION) + data = sw.to_dict() + assert data["name"] == SUBWORKFLOW_NAME + assert data["application"]["name"] == SUBWORKFLOW_APPLICATION.name diff --git a/tests/py/test_unit.py b/tests/py/test_unit.py new file mode 100644 index 00000000..c55cd0a0 --- /dev/null +++ b/tests/py/test_unit.py @@ -0,0 +1,80 @@ +import pytest +from mat3ra.standata.applications import ApplicationStandata +from mat3ra.standata.workflows import WorkflowStandata + +from mat3ra.wode import Unit + +WORKFLOW_STANDATA = WorkflowStandata() +APPLICATION_STANDATA = ApplicationStandata() + +DEFAULT_WF_NAME = WORKFLOW_STANDATA.get_default()["name"] +APPLICATION_ESPRESSO = APPLICATION_STANDATA.get_by_name_first_match("espresso")["name"] + +UNIT_FLOWCHART_ID = "abc-123-def" +UNIT_NEXT_ID = "next-456" + +NEW_CONTEXT_RELAX = { + "kgrid": {"density": 0.5}, + "convergence": {"threshold": 1e-6} +} + +UNIT_CONFIG_EXECUTION = { + "type": "execution", + "name": "pw_scf", + "flowchartId": UNIT_FLOWCHART_ID, + "head": True, +} + +UNIT_CONFIG_ASSIGNMENT = { + "type": "assignment", + "name": "kgrid", + "flowchartId": "kgrid-flowchart-id", + "head": False, +} + + +@pytest.mark.parametrize("config", [UNIT_CONFIG_EXECUTION, UNIT_CONFIG_ASSIGNMENT]) +def test_creation(config): + unit = Unit(**config) + assert unit.type == config["type"] + assert unit.name == config["name"] + + +def test_snake_case_properties(): + unit = Unit(**UNIT_CONFIG_EXECUTION) + assert unit.flowchart_id == UNIT_FLOWCHART_ID + + +@pytest.mark.parametrize("head_value", [True, False]) +def test_head_property(head_value): + config = {**UNIT_CONFIG_EXECUTION, "head": head_value} + unit = Unit(**config) + assert unit.head == head_value + + +def test_next_property(): + config = {**UNIT_CONFIG_EXECUTION, "next": UNIT_NEXT_ID} + unit = Unit(**config) + assert unit.next == UNIT_NEXT_ID + + +def test_to_dict(): + unit = Unit(**UNIT_CONFIG_EXECUTION) + data = unit.to_dict() + assert data["type"] == UNIT_CONFIG_EXECUTION["type"] + assert data["name"] == UNIT_CONFIG_EXECUTION["name"] + assert data["head"] is True + + +def test_add_context(): + unit = Unit(**{**UNIT_CONFIG_EXECUTION, "name": "relaxation step"}) + + assert unit is not None + assert "relax" in unit.name.lower() + + unit.add_context(NEW_CONTEXT_RELAX) + + assert "kgrid" in unit.context + assert "convergence" in unit.context + assert unit.context["kgrid"] == NEW_CONTEXT_RELAX["kgrid"] + assert unit.context["convergence"] == NEW_CONTEXT_RELAX["convergence"] diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py new file mode 100644 index 00000000..355094bc --- /dev/null +++ b/tests/py/test_workflow.py @@ -0,0 +1,183 @@ +import pytest +from mat3ra.standata.applications import ApplicationStandata +from mat3ra.standata.subworkflows import SubworkflowStandata +from mat3ra.standata.workflows import WorkflowStandata + +from mat3ra.wode import Subworkflow, Unit, Workflow + +WORKFLOW_STANDATA = WorkflowStandata() +SUBWORKFLOW_STANDATA = SubworkflowStandata() +APPLICATION_STANDATA = ApplicationStandata() + +WORKFLOW_NAME = WORKFLOW_STANDATA.get_by_name_first_match( + "band_gap" +)["name"] +SUBWORKFLOW_NAME = SUBWORKFLOW_STANDATA.get_by_name_first_match( + "pw_scf" +)["name"] +DEFAULT_WF_NAME = WORKFLOW_STANDATA.get_default()["name"] + +APPLICATION_ESPRESSO = APPLICATION_STANDATA.get_by_name_first_match("espresso")["name"] +APPLICATION_VASP = APPLICATION_STANDATA.get_by_name_first_match("vasp")["name"] +APPLICATION_PYTHON = APPLICATION_STANDATA.get_by_name_first_match("python")["name"] +RELAXATION_NAME = SUBWORKFLOW_STANDATA.get_relaxation_by_application(APPLICATION_ESPRESSO)["name"] + +UNIT_CONFIG = { + "type": "execution", + "name": "pw_scf", + "flowchartId": "unit-flowchart-id", + "head": True, +} + + +def test_creation(): + wf = Workflow(name=WORKFLOW_NAME) + assert wf.name == WORKFLOW_NAME + + +def test_subworkflows(): + sw = Subworkflow(name=SUBWORKFLOW_NAME) + wf = Workflow(name=WORKFLOW_NAME, subworkflows=[sw]) + assert len(wf.subworkflows) == 1 + assert wf.subworkflows[0].name == SUBWORKFLOW_NAME + + +def test_with_units(): + unit = Unit(**UNIT_CONFIG) + wf = Workflow(name=WORKFLOW_NAME, units=[unit]) + assert len(wf.units) == 1 + assert wf.units[0].name == UNIT_CONFIG["name"] + + +def test_field_id_generation(): + wf1 = Workflow(name=WORKFLOW_NAME) + wf2 = Workflow(name=WORKFLOW_NAME) + assert wf1.field_id != wf2.field_id + + +def test_to_dict(): + wf = Workflow(name=WORKFLOW_NAME) + data = wf.to_dict() + assert data["name"] == WORKFLOW_NAME + + +def test_add_subworkflow(): + wf = Workflow(name=WORKFLOW_NAME) + sw = Subworkflow(name=SUBWORKFLOW_NAME) + wf.add_subworkflow(sw) + assert len(wf.subworkflows) == 1 + assert wf.subworkflows[0].name == SUBWORKFLOW_NAME + assert len(wf.units) == 1 + assert wf.units[0].name == SUBWORKFLOW_NAME + assert wf.units[0].type == "subworkflow" + + +@pytest.mark.parametrize( + "application,has_relaxation", + [ + (APPLICATION_ESPRESSO, True), + (APPLICATION_VASP, True), + (APPLICATION_PYTHON, False), + ], +) +def test_get_relaxation_subworkflow(application, has_relaxation): + workflows = WORKFLOW_STANDATA.get_by_categories(application, DEFAULT_WF_NAME) + if not workflows: + pytest.skip(f"No {DEFAULT_WF_NAME} workflow found for {application}") + + workflow_config = workflows[0] + wf = Workflow(**workflow_config) + + result = wf.relaxation_subworkflow + if has_relaxation: + assert result is not None + assert result.name == RELAXATION_NAME + assert hasattr(result, 'name') + else: + assert result is None + + +@pytest.mark.parametrize( + "application", + [APPLICATION_ESPRESSO, APPLICATION_VASP], +) +def test_add_relaxation(application): + workflows = WORKFLOW_STANDATA.get_by_categories(application, DEFAULT_WF_NAME) + if not workflows: + pytest.skip(f"No {DEFAULT_WF_NAME} workflow found for {application}") + + workflow_config = workflows[0] + wf = Workflow(**workflow_config) + + initial_subworkflow_count = len(wf.subworkflows) + assert not wf.has_relaxation + + wf.add_relaxation() + + assert wf.has_relaxation + assert len(wf.subworkflows) == initial_subworkflow_count + 1 + assert wf.subworkflows[0].name == wf.relaxation_subworkflow.name + + +@pytest.mark.parametrize( + "application", + [APPLICATION_ESPRESSO, APPLICATION_VASP], +) +def test_remove_relaxation(application): + workflows = WORKFLOW_STANDATA.get_by_categories(application, DEFAULT_WF_NAME) + if not workflows: + pytest.skip(f"No {DEFAULT_WF_NAME} workflow found for {application}") + + workflow_config = workflows[0] + wf = Workflow(**workflow_config) + + wf.add_relaxation() + assert wf.has_relaxation + initial_subworkflow_count = len(wf.subworkflows) + + wf.remove_relaxation() + + assert not wf.has_relaxation + assert len(wf.subworkflows) == initial_subworkflow_count - 1 + + +@pytest.mark.parametrize( + "method", + [ + "only_new_unit", + "with_unit_instance", + "with_flowchart_id", + ], +) +def test_set_unit(method): + workflows = WORKFLOW_STANDATA.get_by_categories(APPLICATION_ESPRESSO, DEFAULT_WF_NAME) + if not workflows: + pytest.skip(f"No {DEFAULT_WF_NAME} workflow found for {APPLICATION_ESPRESSO}") + + workflow_config = workflows[0] + wf = Workflow(**workflow_config) + + wf.add_relaxation() + + unit_to_modify = wf.get_unit_by_name(name_regex="relax") + assert unit_to_modify is not None + + new_context = {"test_key": "test_value", "another_key": 42} + unit_to_modify.add_context(new_context) + + if method == "only_new_unit": + success = wf.set_unit(unit_to_modify) + elif method == "with_unit_instance": + original_unit = wf.get_unit_by_name(name_regex="relax") + success = wf.set_unit(unit_to_modify, unit=original_unit) + elif method == "with_flowchart_id": + flowchart_id = unit_to_modify.flowchartId + success = wf.set_unit(unit_to_modify, unit_flowchart_id=flowchart_id) + + assert success is True + + updated_unit = wf.get_unit_by_name(name_regex="relax") + assert "test_key" in updated_unit.context + assert "another_key" in updated_unit.context + assert updated_unit.context["test_key"] == "test_value" + assert updated_unit.context["another_key"] == 42 diff --git a/tests/py/units/__init__.py b/tests/py/units/__init__.py new file mode 100644 index 00000000..04a283c1 --- /dev/null +++ b/tests/py/units/__init__.py @@ -0,0 +1,2 @@ +# Unit tests package + diff --git a/tests/py/units/test_assertion_unit.py b/tests/py/units/test_assertion_unit.py new file mode 100644 index 00000000..01d5929f --- /dev/null +++ b/tests/py/units/test_assertion_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import AssertionUnit + + +def test_default_values(): + unit = AssertionUnit(type="assertion", name="test", statement="x > 0") + assert unit.type == "assertion" + assert unit.statement == "x > 0" diff --git a/tests/py/units/test_assignment_unit.py b/tests/py/units/test_assignment_unit.py new file mode 100644 index 00000000..3a37df96 --- /dev/null +++ b/tests/py/units/test_assignment_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import AssignmentUnit + + +def test_default_values(): + unit = AssignmentUnit(type="assignment", name="test", operand="x", value="1") + assert unit.type == "assignment" + assert unit.operand == "x" diff --git a/tests/py/units/test_condition_unit.py b/tests/py/units/test_condition_unit.py new file mode 100644 index 00000000..8874bd9f --- /dev/null +++ b/tests/py/units/test_condition_unit.py @@ -0,0 +1,13 @@ +from mat3ra.wode import ConditionUnit + + +def test_default_values(): + unit = ConditionUnit( + type="condition", + name="test", + statement="x > 0", + then="a", + else_="b", + ) + assert unit.type == "condition" + assert unit.statement == "x > 0" diff --git a/tests/py/units/test_execution_unit.py b/tests/py/units/test_execution_unit.py new file mode 100644 index 00000000..19383c9f --- /dev/null +++ b/tests/py/units/test_execution_unit.py @@ -0,0 +1,6 @@ +from mat3ra.wode import ExecutionUnit + + +def test_default_values(): + unit = ExecutionUnit(type="execution", name="test") + assert unit.type == "execution" diff --git a/tests/py/units/test_io_unit.py b/tests/py/units/test_io_unit.py new file mode 100644 index 00000000..22d846eb --- /dev/null +++ b/tests/py/units/test_io_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import IOUnit + + +def test_default_values(): + unit = IOUnit(type="io", name="test", subtype="input", source="api") + assert unit.type == "io" + assert unit.subtype.value == "input" diff --git a/tests/py/units/test_map_unit.py b/tests/py/units/test_map_unit.py new file mode 100644 index 00000000..ad5bc048 --- /dev/null +++ b/tests/py/units/test_map_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import MapUnit + + +def test_default_values(): + unit = MapUnit(type="map", name="test", workflowId="wf-123") + assert unit.type == "map" + assert unit.workflowId == "wf-123" diff --git a/tests/py/units/test_processing_unit.py b/tests/py/units/test_processing_unit.py new file mode 100644 index 00000000..092f981f --- /dev/null +++ b/tests/py/units/test_processing_unit.py @@ -0,0 +1,6 @@ +from mat3ra.wode import ProcessingUnit + + +def test_default_values(): + unit = ProcessingUnit(type="processing", name="test") + assert unit.type == "processing" diff --git a/tests/py/units/test_reduce_unit.py b/tests/py/units/test_reduce_unit.py new file mode 100644 index 00000000..f05dd698 --- /dev/null +++ b/tests/py/units/test_reduce_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import ReduceUnit + + +def test_default_values(): + unit = ReduceUnit(type="reduce", name="test", mapFlowchartId="map-123") + assert unit.type == "reduce" + assert unit.mapFlowchartId == "map-123" diff --git a/tests/py/units/test_subworkflow_unit.py b/tests/py/units/test_subworkflow_unit.py new file mode 100644 index 00000000..623081de --- /dev/null +++ b/tests/py/units/test_subworkflow_unit.py @@ -0,0 +1,7 @@ +from mat3ra.wode import SubworkflowUnit + + +def test_default_values(): + unit = SubworkflowUnit(type="subworkflow", name="test") + assert unit.type == "subworkflow" + assert unit.name == "test"