diff --git a/pyproject.toml b/pyproject.toml index 86db7f4..49ca6f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ dynamic = ["version"] description = "WOrkflow DEfinitions" readme = "README.md" requires-python = ">=3.10" -license = {file = "LICENSE.md"} +license = { file = "LICENSE.md" } authors = [ { name = "Exabyte Inc.", email = "info@mat3ra.com" } ] @@ -24,6 +24,7 @@ dependencies = [ "mat3ra-esse", "mat3ra-mode", "mat3ra-ade", + "mat3ra-code", "mat3ra-standata" ] diff --git a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py index 7b00e8a..83b57a4 100644 --- a/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py +++ b/src/py/mat3ra/wode/context/providers/points_grid_data_provider.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List from mat3ra.ade.context.context_provider import ContextProvider -from mat3ra.esse.models.context_providers_directory.enum import ContextProviderNameEnum from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import ( GridMetricType, PointsGridDataProviderSchema, @@ -17,14 +16,16 @@ class PointsGridDataProvider(PointsGridDataProviderSchema, ContextProvider): Handles grid dimensions and shifts for reciprocal space sampling. """ - # TODO: Verify the correctness of the name - name: ContextProviderNameEnum = ContextProviderNameEnum.KGridFormDataManager + name: str = Field(default="kgrid") divisor: int = Field(default=1) dimensions: List[int] = Field(default_factory=lambda: [1, 1, 1]) shifts: List[float] = Field(default_factory=lambda: [0.0, 0.0, 0.0]) grid_metric_type: str = Field(default=GridMetricType.KPPRA) - # TODO: handle presence of material + @property + def is_edited_key(self) -> str: + return "isKgridEdited" + @property def default_data(self) -> Dict[str, Any]: return { @@ -39,7 +40,7 @@ def get_default_grid_metric_value(self, metric: str) -> float: raise NotImplementedError def calculate_dimensions( - self, grid_metric_type: str, grid_metric_value: float, units: str = "angstrom" + self, grid_metric_type: str, grid_metric_value: float, units: str = "angstrom" ) -> List[int]: raise NotImplementedError diff --git a/src/py/mat3ra/wode/units/unit.py b/src/py/mat3ra/wode/units/unit.py index 1df93e8..ea36f14 100644 --- a/src/py/mat3ra/wode/units/unit.py +++ b/src/py/mat3ra/wode/units/unit.py @@ -19,7 +19,7 @@ class Unit(WorkflowBaseUnitSchema, InMemoryEntitySnakeCase): tags: List of tags for the unit context: Context data dictionary for the unit """ - + id: str = Field(default_factory=get_uuid, alias="_id") flowchartId: str = Field(default_factory=get_uuid) # TODO: use RuntimeItemNameObjectSchema when available preProcessors: List[Any] = Field(default_factory=list) @@ -28,6 +28,7 @@ class Unit(WorkflowBaseUnitSchema, InMemoryEntitySnakeCase): results: List[Any] = Field(default_factory=list) context: Dict[str, Any] = Field(default_factory=dict) + def is_in_status(self, status: str) -> bool: return self.status == status diff --git a/tests/py/context/test_points_grid_data_provider.py b/tests/py/context/test_points_grid_data_provider.py index c710e61..99f7bcf 100644 --- a/tests/py/context/test_points_grid_data_provider.py +++ b/tests/py/context/test_points_grid_data_provider.py @@ -1,5 +1,6 @@ import pytest from mat3ra.wode.context.providers import PointsGridDataProvider +from mat3ra.esse.models.context_providers_directory.points_grid_data_provider import GridMetricType # Test data constants DIMENSIONS_DEFAULT = [1, 1, 1] @@ -8,10 +9,17 @@ SHIFTS_CUSTOM = [0.5, 0.5, 0.5] DIVISOR_DEFAULT = 1 DIVISOR_CUSTOM = 2 -DATA_CUSTOM = { - "dimensions": DIMENSIONS_CUSTOM, - "shifts": SHIFTS_CUSTOM, - "divisor": DIVISOR_CUSTOM, +GRID_METRIC_TYPE_DEFAULT = GridMetricType.KPPRA + +# Expected data structures +KGRID_DATA = { + "kgrid": { + "dimensions": DIMENSIONS_CUSTOM, + "shifts": SHIFTS_DEFAULT, + "divisor": DIVISOR_DEFAULT, + "gridMetricType": GRID_METRIC_TYPE_DEFAULT, + }, + "isKgridEdited": True, } @@ -35,60 +43,32 @@ def test_points_grid_data_provider_initialization(init_params, expected_dimensio @pytest.mark.parametrize( - "init_params,expected_dimensions,expected_shifts,expected_divisor", + "init_params,expected_data", [ ( - {"dimensions": DIMENSIONS_CUSTOM}, - DIMENSIONS_CUSTOM, - SHIFTS_DEFAULT, - DIVISOR_DEFAULT, + {"dimensions": DIMENSIONS_CUSTOM}, + KGRID_DATA, ), ], ) -def test_points_grid_data_provider_get_data(init_params, expected_dimensions, expected_shifts, expected_divisor): - kgrid_context_provider_relax = PointsGridDataProvider(**init_params) - - new_context_relax = kgrid_context_provider_relax.get_data() - - assert isinstance(new_context_relax, dict) - assert "dimensions" in new_context_relax - assert "shifts" in new_context_relax - assert "divisor" in new_context_relax - assert "gridMetricType" in new_context_relax +def test_points_grid_data_provider_get_data(init_params, expected_data): + kgrid_context_provider = PointsGridDataProvider(**init_params) + actual_data = kgrid_context_provider.get_data() + assert actual_data == expected_data["kgrid"] - assert new_context_relax["dimensions"] == expected_dimensions - assert new_context_relax["shifts"] == expected_shifts - assert new_context_relax["divisor"] == expected_divisor @pytest.mark.parametrize( - "init_params,expected_dimensions,expected_shifts,expected_divisor", + "init_params,expected_data", [ ( - {"dimensions": DIMENSIONS_CUSTOM}, - DIMENSIONS_CUSTOM, - SHIFTS_DEFAULT, - DIVISOR_DEFAULT, + {"dimensions": DIMENSIONS_CUSTOM, "is_edited": True}, + KGRID_DATA, ), ], ) -def test_points_grid_data_provider_yield_data(init_params, expected_dimensions, expected_shifts, expected_divisor): - kgrid_context_provider_relax = PointsGridDataProvider(**init_params) - - yielded_context = kgrid_context_provider_relax.yield_data() - - print(yielded_context) - assert isinstance(yielded_context, dict) - assert "KGridFormDataManager" in yielded_context - assert "isKGridFormDataManagerEdited" in yielded_context - - data = yielded_context["KGridFormDataManager"] - assert isinstance(data, dict) - assert "dimensions" in data - assert "shifts" in data - assert "divisor" in data - assert "gridMetricType" in data +def test_points_grid_data_provider_yield_data(init_params, expected_data): + kgrid_context_provider = PointsGridDataProvider(**init_params) + actual_data = kgrid_context_provider.yield_data() + assert actual_data == expected_data - assert data["dimensions"] == expected_dimensions - assert data["shifts"] == expected_shifts - assert data["divisor"] == expected_divisor diff --git a/tests/py/test_subworkflow.py b/tests/py/test_subworkflow.py index 2fbd524..d9091ac 100644 --- a/tests/py/test_subworkflow.py +++ b/tests/py/test_subworkflow.py @@ -59,3 +59,11 @@ def test_id_generation(): sw1 = Subworkflow(name=SUBWORKFLOW_NAME) sw2 = Subworkflow(name=SUBWORKFLOW_NAME) assert sw1.id != sw2.id + +def test_get_as_unit(): + sw = Subworkflow(name=SUBWORKFLOW_NAME) + unit = sw.get_as_unit() + assert unit.type == "subworkflow" + assert unit.id == sw.id + assert unit.to_dict().get("_id") == sw.id + assert unit.name == sw.name diff --git a/tests/py/test_workflow.py b/tests/py/test_workflow.py index 10cd1ca..abb35ef 100644 --- a/tests/py/test_workflow.py +++ b/tests/py/test_workflow.py @@ -2,6 +2,7 @@ from mat3ra.standata.applications import ApplicationStandata from mat3ra.standata.subworkflows import SubworkflowStandata from mat3ra.standata.workflows import WorkflowStandata + from mat3ra.wode import Subworkflow, Unit, Workflow WORKFLOW_STANDATA = WorkflowStandata()