diff --git a/antarest/core/exceptions.py b/antarest/core/exceptions.py index 324c71ccf4..857ea1462f 100644 --- a/antarest/core/exceptions.py +++ b/antarest/core/exceptions.py @@ -189,8 +189,8 @@ def __init__(self, *area_ids: str) -> None: ids = ", ".join(f"'{a}'" for a in area_ids) msg = { 0: "All areas are found", - 1: f"{count} area is not found: {ids}", - 2: f"{count} areas are not found: {ids}", + 1: f"Area is not found: {ids}", + 2: f"Areas are not found: {ids}", }[min(count, 2)] super().__init__(HTTPStatus.NOT_FOUND, msg) diff --git a/antarest/study/business/correlation_management.py b/antarest/study/business/correlation_management.py new file mode 100644 index 0000000000..763dbdc12d --- /dev/null +++ b/antarest/study/business/correlation_management.py @@ -0,0 +1,369 @@ +""" +Management of spatial correlations between the different generators. +The generators are of the same category and can be hydraulic, wind, load or solar. +""" +import collections +from typing import Dict, List, Sequence + +import numpy as np +import numpy.typing as npt +from antarest.core.exceptions import AreaNotFound +from antarest.study.business.area_management import AreaInfoDTO +from antarest.study.business.utils import ( + FormFieldsBaseModel, + execute_or_add_commands, +) +from antarest.study.model import Study +from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy +from antarest.study.storage.storage_service import StudyStorageService +from antarest.study.storage.variantstudy.model.command.update_config import ( + UpdateConfig, +) +from pydantic import conlist, validator + + +class CorrelationField(FormFieldsBaseModel): + """ + Model for correlation coefficients of a given area. + + Attributes: + area_id: Area identifier. + coefficient: correlation coefficients in percentage (-100 <= coefficient <= 100). + """ + + class Config: + allow_population_by_field_name = True + + area_id: str + coefficient: float + + +class CorrelationFormFields(FormFieldsBaseModel): + """ + Model for a list of consumption coefficients for each area. + + Attributes: + correlation: A list of non-null correlation coefficients in percentage. + """ + + correlation: List[CorrelationField] + + # noinspection PyMethodParameters + @validator("correlation") + def check_correlation( + cls, correlation: List[CorrelationField] + ) -> List[CorrelationField]: + if not correlation: + raise ValueError("correlation must not be empty") + counter = collections.Counter(field.area_id for field in correlation) + if duplicates := {id_ for id_, count in counter.items() if count > 1}: + raise ValueError( + f"correlation must not contain duplicate area IDs: {duplicates}" + ) + # fmt: off + array = np.array([a.coefficient for a in correlation], dtype=np.float64) + if np.any((array < -100) | np.any(array > 100)): + raise ValueError("percentage must be between -100 and 100") + if np.any(np.isnan(array)): + raise ValueError("correlation matrix must not contain NaN coefficients") + # fmt: on + return correlation + + +class CorrelationMatrix(FormFieldsBaseModel): + """ + Correlation matrix for hydraulic, wind, load, or solar generators. + + Attributes: + index: A list of all study areas. + columns: A list of selected production areas. + data: A 2D-array matrix of correlation coefficients. + """ + + index: conlist(str, min_items=1) # type: ignore + columns: conlist(str, min_items=1) # type: ignore + data: List[List[float]] # NonNegativeFloat not necessary + + # noinspection PyMethodParameters + @validator("data") + def validate_correlation_matrix( + cls, data: List[List[float]], values: Dict[str, List[str]] + ) -> List[List[float]]: + """ + Validates the correlation matrix by checking its shape and range of coefficients. + + Args: + cls: The `CorrelationMatrix` class. + data: The correlation matrix to validate. + values: A dictionary containing the values of `index` and `columns`. + + Returns: + List[List[float]]: The validated correlation matrix. + + Raises: + ValueError: + If the correlation matrix is empty, + has an incorrect shape, + is squared but not symmetric, + or contains coefficients outside the range of -1 to 1 + or NaN coefficients. + """ + + array = np.array(data) + rows = len(values.get("index", [])) + cols = len(values.get("columns", [])) + + # fmt: off + if array.size == 0: + raise ValueError("correlation matrix must not be empty") + if array.shape != (rows, cols): + raise ValueError(f"correlation matrix must have shape ({rows}×{cols})") + if np.any((array < -1) | np.any(array > 1)): + raise ValueError("coefficients must be between -1 and 1") + if np.any(np.isnan(array)): + raise ValueError("correlation matrix must not contain NaN coefficients") + if ( + array.shape[0] == array.shape[1] + and not np.array_equal(array, array.T) + ): + raise ValueError("correlation matrix is not symmetric") + # fmt: on + return data + + class Config: + schema_extra = { + "example": { + "columns": ["north", "east", "south", "west"], + "data": [ + [0.0, 0.0, 0.25, 0.0], + [0.0, 0.0, 0.75, 0.12], + [0.25, 0.75, 0.0, 0.75], + [0.0, 0.12, 0.75, 0.0], + ], + "index": ["north", "east", "south", "west"], + } + } + + +def _config_to_array( + area_ids: Sequence[str], + correlation_cfg: Dict[str, str], +) -> npt.NDArray[np.float64]: + array = np.identity(len(area_ids), dtype=np.float64) + for key, value in correlation_cfg.items(): + a1, a2 = key.split("%") + i = area_ids.index(a1) + j = area_ids.index(a2) + if i == j: + # ignored: values from the diagonal are always == 1.0 + continue + coefficient = value + array[i][j] = coefficient + array[j][i] = coefficient + return array + + +def _array_to_config( + area_ids: Sequence[str], + array: npt.NDArray[np.float64], +) -> Dict[str, str]: + correlation_cfg: Dict[str, str] = {} + count = len(area_ids) + for i in range(count): + # not saved: values from the diagonal are always == 1.0 + for j in range(i + 1, count): + coefficient = array[i][j] + if not coefficient: + # null values are not saved + continue + a1 = area_ids[i] + a2 = area_ids[j] + correlation_cfg[f"{a1}%{a2}"] = coefficient + return correlation_cfg + + +class CorrelationManager: + """ + This manager allows you to read and write the hydraulic, wind, load or solar + correlation matrices of a raw study or a variant. + """ + + # Today, only the 'hydro' category is fully supported, but + # we could also manage the 'load' 'solar' and 'wind' + # categories but the usage is deprecated. + url = ["input", "hydro", "prepro", "correlation", "annual"] + + def __init__(self, storage_service: StudyStorageService) -> None: + self.storage_service = storage_service + + def _get_array( + self, + file_study: FileStudy, + area_ids: Sequence[str], + ) -> npt.NDArray[np.float64]: + correlation_cfg = file_study.tree.get(self.url, depth=3) + return _config_to_array(area_ids, correlation_cfg) + + def _set_array( + self, + study: Study, + file_study: FileStudy, + area_ids: Sequence[str], + array: npt.NDArray[np.float64], + ) -> None: + correlation_cfg = _array_to_config(area_ids, array) + command_context = ( + self.storage_service.variant_study_service.command_factory.command_context + ) + command = UpdateConfig( + target="/".join(self.url), + data=correlation_cfg, + command_context=command_context, + ) + execute_or_add_commands( + study, file_study, [command], self.storage_service + ) + + def get_correlation_form_fields( + self, all_areas: List[AreaInfoDTO], study: Study, area_id: str + ) -> CorrelationFormFields: + """ + Get the correlation form fields (percentage values) for a given area. + + Args: + all_areas: list of all areas in the study. + study: study to get the correlation coefficients from. + area_id: area to get the correlation coefficients from. + + Returns: + The correlation coefficients. + """ + file_study = self.storage_service.get_storage(study).get_raw(study) + area_ids = [area.id for area in all_areas] + array = self._get_array(file_study, area_ids) + column = array[:, area_ids.index(area_id)] * 100 + return CorrelationFormFields.construct( + correlation=[ + CorrelationField.construct(area_id=a, coefficient=c) + for a, c in zip(area_ids, column) + if c + ] + ) + + def set_correlation_form_fields( + self, + all_areas: List[AreaInfoDTO], + study: Study, + area_id: str, + data: CorrelationFormFields, + ) -> CorrelationFormFields: + """ + Set the correlation coefficients of a given area from the form fields (percentage values). + + Args: + all_areas: list of all areas in the study. + study: study to set the correlation coefficients to. + area_id: area to set the correlation coefficients to. + data: correlation coefficients to set. + + Raises: + AreaNotFound: if the area is not found or invalid. + + Returns: + The updated correlation coefficients. + """ + correlation_ids = {field.area_id for field in data.correlation} + area_ids = [area.id for area in all_areas] + if invalid_ids := correlation_ids - set(area_ids): + # sort for deterministic error message and testing + raise AreaNotFound(*sorted(invalid_ids)) + + file_study = self.storage_service.get_storage(study).get_raw(study) + array = self._get_array(file_study, area_ids) + for field in data.correlation: + i = area_ids.index(field.area_id) + j = area_ids.index(area_id) + array[i][j] = field.coefficient / 100 + self._set_array(study, file_study, area_ids, array) + + column = array[:, area_ids.index(area_id)] * 100 + return CorrelationFormFields.construct( + correlation=[ + CorrelationField.construct(area_id=a, coefficient=c) + for a, c in zip(area_ids, column) + if c + ] + ) + + def get_correlation_matrix( + self, all_areas: List[AreaInfoDTO], study: Study, columns: List[str] + ) -> CorrelationMatrix: + """ + Read the correlation coefficients and get the correlation matrix (values in the range -1 to 1). + + Args: + all_areas: list of all areas in the study. + study: study to get the correlation matrix from. + columns: areas to get the correlation matrix from. + + Returns: + The correlation matrix. + """ + file_study = self.storage_service.get_storage(study).get_raw(study) + area_ids = [area.id for area in all_areas] + columns = ( + [a for a in area_ids if a in columns] if columns else area_ids + ) + array = self._get_array(file_study, area_ids) + # noinspection PyTypeChecker + data = [ + [c for i, c in enumerate(row) if area_ids[i] in columns] + for row in array.tolist() + ] + + return CorrelationMatrix.construct( + index=area_ids, columns=columns, data=data + ) + + def set_correlation_matrix( + self, + all_areas: List[AreaInfoDTO], + study: Study, + matrix: CorrelationMatrix, + ) -> CorrelationMatrix: + """ + Set the correlation coefficients from the coefficient matrix (values in the range -1 to 1). + + Args: + all_areas: list of all areas in the study. + study: study to get the correlation matrix from. + matrix: correlation matrix to update + + Returns: + The updated correlation matrix. + """ + file_study = self.storage_service.get_storage(study).get_raw(study) + area_ids = [area.id for area in all_areas] + + array = self._get_array(file_study, area_ids) + + for row, a1 in zip(matrix.data, matrix.index): + for coefficient, a2 in zip(row, matrix.columns): + if missing := {a1, a2} - set(area_ids): + raise AreaNotFound(*missing) + i = area_ids.index(a1) + j = area_ids.index(a2) + array[i][j] = coefficient + array[j][i] = coefficient + + self._set_array(study, file_study, area_ids, array) + + # noinspection PyTypeChecker + data = [ + [c for i, c in enumerate(row) if area_ids[i] in matrix.columns] + for row in array.tolist() + ] + + return CorrelationMatrix.construct( + index=area_ids, columns=matrix.columns, data=data + ) diff --git a/antarest/study/storage/variantstudy/model/command/create_area.py b/antarest/study/storage/variantstudy/model/command/create_area.py index 2cc832c612..cd029cbed1 100644 --- a/antarest/study/storage/variantstudy/model/command/create_area.py +++ b/antarest/study/storage/variantstudy/model/command/create_area.py @@ -91,16 +91,12 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: area_id = data["area_id"] version = study_data.config.version + # fmt: off hydro_config = study_data.tree.get(["input", "hydro", "hydro"]) - get_or_create_section(hydro_config, "inter-daily-breakdown")[ - area_id - ] = 1 - get_or_create_section(hydro_config, "intra-daily-modulation")[ - area_id - ] = 24 - get_or_create_section(hydro_config, "inter-monthly-breakdown")[ - area_id - ] = 1 + get_or_create_section(hydro_config, "inter-daily-breakdown")[area_id] = 1 + get_or_create_section(hydro_config, "intra-daily-modulation")[area_id] = 24 + get_or_create_section(hydro_config, "inter-monthly-breakdown")[area_id] = 1 + # fmt: on new_area_data: JSON = { "input": { @@ -229,30 +225,22 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: } if version > 650: - get_or_create_section(hydro_config, "initialize reservoir date")[ - area_id - ] = 0 + # fmt: off + get_or_create_section(hydro_config, "initialize reservoir date")[area_id] = 0 get_or_create_section(hydro_config, "leeway low")[area_id] = 1 get_or_create_section(hydro_config, "leeway up")[area_id] = 1 - get_or_create_section(hydro_config, "pumping efficiency")[ - area_id - ] = 1 + get_or_create_section(hydro_config, "pumping efficiency")[area_id] = 1 - new_area_data["input"]["hydro"]["common"]["capacity"][ - f"creditmodulations_{area_id}" - ] = ( + new_area_data["input"]["hydro"]["common"]["capacity"][f"creditmodulations_{area_id}"] = ( self.command_context.generator_matrix_constants.get_hydro_credit_modulations() ) - new_area_data["input"]["hydro"]["common"]["capacity"][ - f"inflowPattern_{area_id}" - ] = ( + new_area_data["input"]["hydro"]["common"]["capacity"][f"inflowPattern_{area_id}"] = ( self.command_context.generator_matrix_constants.get_hydro_inflow_pattern() ) - new_area_data["input"]["hydro"]["common"]["capacity"][ - f"waterValues_{area_id}" - ] = ( + new_area_data["input"]["hydro"]["common"]["capacity"][f"waterValues_{area_id}"] = ( self.command_context.generator_matrix_constants.get_null_matrix() ) + # fmt: on if version >= 830: new_area_data["input"]["areas"][area_id]["adequacy_patch"] = { @@ -261,6 +249,15 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: new_area_data["input"]["hydro"]["hydro"] = hydro_config + # NOTE regarding the following configurations: + # - ["input", "hydro", "prepro", "correlation"] + # - ["input", "load", "prepro", "correlation"] + # - ["input", "solar", "prepro", "correlation"] + # - ["input", "wind", "prepro", "correlation"] + # When creating a new area, we should not add a new correlation + # value to the configuration because it does not store the values + # of the diagonal (always equal to 1). + study_data.tree.save(new_area_data) return output diff --git a/antarest/study/storage/variantstudy/model/command/remove_area.py b/antarest/study/storage/variantstudy/model/command/remove_area.py index d02c17c639..ca633b0d2e 100644 --- a/antarest/study/storage/variantstudy/model/command/remove_area.py +++ b/antarest/study/storage/variantstudy/model/command/remove_area.py @@ -141,6 +141,32 @@ def _remove_area_from_hydro_allocation( allocations.pop(self.id, None) study_data.tree.save(allocation_cfg, ["input", "hydro", "allocation"]) + def _remove_area_from_correlation_matrices( + self, study_data: FileStudy + ) -> None: + """ + Removes the values from the correlation matrix that match the current area. + + This function can update the following configurations: + - ["input", "hydro", "prepro", "correlation"] + + Args: + study_data:File Study to update. + """ + # Today, only the 'hydro' category is fully supported, but + # we could also manage the 'load' 'solar' and 'wind' + # categories but the usage is deprecated. + url = ["input", "hydro", "prepro", "correlation"] + correlation_cfg = study_data.tree.get(url) + for section, correlation in correlation_cfg.items(): + if section == "general": + continue + for key in list(correlation): + a1, a2 = key.split("%") + if a1 == self.id or a2 == self.id: + del correlation[key] + study_data.tree.save(correlation_cfg, url) + def _remove_area_from_districts(self, study_data: FileStudy) -> None: districts = study_data.tree.get(["input", "areas", "sets"]) for district in districts.values(): @@ -197,6 +223,7 @@ def _apply(self, study_data: FileStudy) -> CommandOutput: self._remove_area_from_links(study_data) self._remove_area_from_binding_constraints(study_data) + self._remove_area_from_correlation_matrices(study_data) self._remove_area_from_hydro_allocation(study_data) self._remove_area_from_districts(study_data) self._remove_area_from_cluster(study_data) diff --git a/antarest/study/web/study_data_blueprint.py b/antarest/study/web/study_data_blueprint.py index 538f2c7f07..8b82df0d40 100644 --- a/antarest/study/web/study_data_blueprint.py +++ b/antarest/study/web/study_data_blueprint.py @@ -2,9 +2,6 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Union, cast -from fastapi import APIRouter, Body, Depends -from fastapi.params import Body - from antarest.core.config import Config from antarest.core.jwt import JWTUser from antarest.core.model import StudyPermissionType @@ -35,6 +32,11 @@ ConstraintTermDTO, UpdateBindingConstProps, ) +from antarest.study.business.correlation_management import ( + CorrelationFormFields, + CorrelationManager, + CorrelationMatrix, +) from antarest.study.business.district_manager import ( DistrictCreationDTO, DistrictInfoDTO, @@ -61,6 +63,8 @@ from antarest.study.business.timeseries_config_management import TSFormFields from antarest.study.model import PatchArea, PatchCluster from antarest.study.service import StudyService +from fastapi import APIRouter, Body, Depends +from fastapi.params import Body, Query logger = logging.getLogger(__name__) @@ -1142,6 +1146,191 @@ def set_allocation_form_fields( all_areas, study, area_id, data ) + @bp.get( + path="/studies/{uuid}/areas/hydro/correlation/matrix", + tags=[APITag.study_data], + summary="Get the hydraulic/load/solar/wind correlation matrix of a study", + response_model=CorrelationMatrix, + ) + def get_correlation_matrix( + uuid: str, + columns: Optional[str] = Query( + None, + examples={ + "all areas": { + "description": "get the correlation matrix for all areas (by default)", + "value": "", + }, + "single area": { + "description": "get the correlation column for a single area", + "value": "north", + }, + "selected areas": { + "description": "get the correlation columns for a selected list of areas", + "value": "north,east", + }, + }, + ), # type: ignore + current_user: JWTUser = Depends(auth.get_current_user), + ) -> CorrelationMatrix: + """ + Get the hydraulic/load/solar/wind correlation matrix of a study. + + Parameters: + - `uuid`: The UUID of the study. + - `columns`: A filter on the area identifiers: + - Use no parameter to select all areas. + - Use an area identifier to select a single area. + - Use a comma-separated list of areas to select those areas. + + Returns the hydraulic/load/solar/wind correlation matrix with the following attributes: + - `index`: A list of all study areas. + - `columns`: A list of selected production areas. + - `data`: A 2D-array matrix of correlation coefficients with values in the range of -1 to 1. + """ + params = RequestParameters(user=current_user) + study = study_service.check_study_access( + uuid, StudyPermissionType.READ, params + ) + all_areas = cast( + List[AreaInfoDTO], # because `ui=False` + study_service.get_all_areas( + uuid, area_type=AreaType.AREA, ui=False, params=params + ), + ) + manager = CorrelationManager(study_service.storage_service) + return manager.get_correlation_matrix( + all_areas, + study, + columns.split(",") if columns else [], + ) + + @bp.put( + path="/studies/{uuid}/areas/hydro/correlation/matrix", + tags=[APITag.study_data], + summary="Set the hydraulic/load/solar/wind correlation matrix of a study", + status_code=HTTPStatus.OK, + response_model=CorrelationMatrix, + ) + def set_correlation_matrix( + uuid: str, + matrix: CorrelationMatrix = Body( + ..., + example={ + "columns": ["north", "east", "south", "west"], + "data": [ + [0.0, 0.0, 0.25, 0.0], + [0.0, 0.0, 0.75, 0.12], + [0.25, 0.75, 0.0, 0.75], + [0.0, 0.12, 0.75, 0.0], + ], + "index": ["north", "east", "south", "west"], + }, + ), + current_user: JWTUser = Depends(auth.get_current_user), + ) -> CorrelationMatrix: + """ + Set the hydraulic/load/solar/wind correlation matrix of a study. + + Parameters: + - `uuid`: The UUID of the study. + - `index`: A list of all study areas. + - `columns`: A list of selected production areas. + - `data`: A 2D-array matrix of correlation coefficients with values in the range of -1 to 1. + + Returns the hydraulic/load/solar/wind correlation matrix updated + """ + params = RequestParameters(user=current_user) + study = study_service.check_study_access( + uuid, StudyPermissionType.WRITE, params + ) + all_areas = cast( + List[AreaInfoDTO], # because `ui=False` + study_service.get_all_areas( + uuid, area_type=AreaType.AREA, ui=False, params=params + ), + ) + manager = CorrelationManager(study_service.storage_service) + return manager.set_correlation_matrix(all_areas, study, matrix) + + @bp.get( + path="/studies/{uuid}/areas/{area_id}/hydro/correlation/form", + tags=[APITag.study_data], + summary="Get the form fields used for the correlation form", + response_model=CorrelationFormFields, + ) + def get_correlation_form_fields( + uuid: str, + area_id: str, + current_user: JWTUser = Depends(auth.get_current_user), + ) -> CorrelationFormFields: + """ + Get the form fields used for the correlation form. + + Parameters: + - `uuid`: The UUID of the study. + - `area_id`: the area ID. + + Returns the correlation form fields in percentage. + """ + params = RequestParameters(user=current_user) + study = study_service.check_study_access( + uuid, StudyPermissionType.READ, params + ) + all_areas = cast( + List[AreaInfoDTO], # because `ui=False` + study_service.get_all_areas( + uuid, area_type=AreaType.AREA, ui=False, params=params + ), + ) + manager = CorrelationManager(study_service.storage_service) + return manager.get_correlation_form_fields(all_areas, study, area_id) + + @bp.put( + path="/studies/{uuid}/areas/{area_id}/hydro/correlation/form", + tags=[APITag.study_data], + summary="Set the form fields used for the correlation form", + status_code=HTTPStatus.OK, + response_model=CorrelationFormFields, + ) + def set_correlation_form_fields( + uuid: str, + area_id: str, + data: CorrelationFormFields = Body( + ..., + example=CorrelationFormFields( + correlation=[ + {"areaId": "east", "coefficient": 80}, + {"areaId": "north", "coefficient": 20}, + ] + ), + ), + current_user: JWTUser = Depends(auth.get_current_user), + ) -> CorrelationFormFields: + """ + Update the hydraulic/load/solar/wind correlation of a given area. + + Parameters: + - `uuid`: The UUID of the study. + - `area_id`: the area ID. + + Returns the correlation form fields in percentage. + """ + params = RequestParameters(user=current_user) + study = study_service.check_study_access( + uuid, StudyPermissionType.WRITE, params + ) + all_areas = cast( + List[AreaInfoDTO], # because `ui=False` + study_service.get_all_areas( + uuid, area_type=AreaType.AREA, ui=False, params=params + ), + ) + manager = CorrelationManager(study_service.storage_service) + return manager.set_correlation_form_fields( + all_areas, study, area_id, data + ) + @bp.get( path="/studies/{uuid}/config/advancedparameters/form", tags=[APITag.study_data], diff --git a/examples/studies/STA-mini.zip b/examples/studies/STA-mini.zip index c033602313..0de2e32809 100644 Binary files a/examples/studies/STA-mini.zip and b/examples/studies/STA-mini.zip differ diff --git a/tests/integration/study_data_blueprint/test_hydro_correlation.py b/tests/integration/study_data_blueprint/test_hydro_correlation.py new file mode 100644 index 0000000000..83aac59495 --- /dev/null +++ b/tests/integration/study_data_blueprint/test_hydro_correlation.py @@ -0,0 +1,253 @@ +from http import HTTPStatus +from typing import List + +import pytest +from antarest.study.business.area_management import AreaInfoDTO +from starlette.testclient import TestClient + + +@pytest.mark.unit_test +class TestHydroCorrelation: + """ + Test the end points related to hydraulic correlation. + + Those tests use the "examples/studies/STA-mini.zip" Study, + which contains the following areas: ["de", "es", "fr", "it"]. + """ + + def test_get_correlation_form_values( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ): + """Check `get_correlation_form_values` end point""" + area_id = "fr" + res = client.get( + f"/v1/studies/{study_id}/areas/{area_id}/hydro/correlation/form", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == HTTPStatus.OK, res.json() + actual = res.json() + expected = { + "correlation": [ + {"areaId": "de", "coefficient": 25.0}, + {"areaId": "es", "coefficient": 75.0}, + {"areaId": "fr", "coefficient": 100.0}, + {"areaId": "it", "coefficient": 75.0}, + ] + } + assert actual == expected + + def test_set_correlation_form_values( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ): + """Check `set_correlation_form_values` end point""" + area_id = "fr" + obj = { + "correlation": [ + {"areaId": "de", "coefficient": 20}, + {"areaId": "es", "coefficient": -80}, + {"areaId": "it", "coefficient": 0}, + ] + } + res = client.put( + f"/v1/studies/{study_id}/areas/{area_id}/hydro/correlation/form", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=obj, + ) + assert res.status_code == HTTPStatus.OK, res.json() + actual = res.json() + expected = { + "correlation": [ + {"areaId": "de", "coefficient": 20.0}, + {"areaId": "es", "coefficient": -80.0}, + {"areaId": "fr", "coefficient": 100.0}, + ] + } + assert actual == expected + + @pytest.mark.parametrize( + "columns, expected", + [ + pytest.param( + "", + { + "columns": ["de", "es", "fr", "it"], + "data": [ + [1.0, 0.0, 0.25, 0.0], + [0.0, 1.0, 0.75, 0.12], + [0.25, 0.75, 1.0, 0.75], + [0.0, 0.12, 0.75, 1.0], + ], + "index": ["de", "es", "fr", "it"], + }, + id="all-areas", + ), + pytest.param( + "fr,de", + { + "columns": ["de", "fr"], + "data": [ + [1.0, 0.25], + [0.0, 0.75], + [0.25, 1.0], + [0.0, 0.75], + ], + "index": ["de", "es", "fr", "it"], + }, + id="some-areas", + ), + pytest.param( + "fr", + { + "columns": ["fr"], + "data": [ + [0.25], + [0.75], + [1.0], + [0.75], + ], + "index": ["de", "es", "fr", "it"], + }, + id="one-area", + ), + ], + ) + def test_get_correlation_matrix( + self, + client: TestClient, + user_access_token: str, + study_id: str, + columns: str, + expected: List[List[float]], + ): + """Check `get_correlation_matrix` end point""" + query = f"columns={columns}" if columns else "" + res = client.get( + f"/v1/studies/{study_id}/areas/hydro/correlation/matrix?{query}", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == HTTPStatus.OK, res.json() + actual = res.json() + assert actual == expected + + def test_set_correlation_matrix( + self, + client: TestClient, + user_access_token: str, + study_id: str, + ): + """Check `set_correlation_matrix` end point""" + obj = { + "columns": ["fr", "it"], + "data": [ + [-0.79332875, -0.96830414], + [-0.23220568, -0.158783], + [1.0, 0.82], + [0.82, 1.0], + ], + "index": ["de", "es", "fr", "it"], + } + res = client.put( + f"/v1/studies/{study_id}/areas/hydro/correlation/matrix", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=obj, + ) + assert res.status_code == HTTPStatus.OK, res.json() + actual = res.json() + expected = obj + assert actual == expected + + def test_create_area( + self, client: TestClient, user_access_token: str, study_id: str + ): + """ + Given a study, when an area is created, the hydraulic correlation + column for this area must be updated with the following values: + - the coefficient == 1 for this area, + - the coefficient == 0 for the other areas. + Other columns must not be changed. + """ + area_info = AreaInfoDTO(id="north", name="NORTH", type="AREA") + res = client.post( + f"/v1/studies/{study_id}/areas", + headers={"Authorization": f"Bearer {user_access_token}"}, + data=area_info.json(), + ) + assert res.status_code == HTTPStatus.OK, res.json() + + res = client.get( + f"/v1/studies/{study_id}/areas/hydro/correlation/matrix", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == HTTPStatus.OK + actual = res.json() + expected = { + "columns": ["de", "es", "fr", "it", "north"], + "data": [ + [1.0, 0.0, 0.25, 0.0, 0.0], + [0.0, 1.0, 0.75, 0.12, 0.0], + [0.25, 0.75, 1.0, 0.75, 0.0], + [0.0, 0.12, 0.75, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + "index": ["de", "es", "fr", "it", "north"], + } + assert actual == expected + + def test_delete_area( + self, client: TestClient, user_access_token: str, study_id: str + ): + """ + Given a study, when an area is deleted, the hydraulic correlation + column for this area must be removed. + Other columns must be updated to reflect the area deletion. + """ + # First change the coefficients to avoid zero values (which are defaults). + correlation_cfg = { + "annual": { + "de%es": 0.12, + "de%fr": 0.13, + "de%it": 0.14, + "es%fr": 0.22, + "es%it": 0.23, + "fr%it": 0.32, + } + } + res = client.post( + f"/v1/studies/{study_id}/raw?path=input/hydro/prepro/correlation", + headers={"Authorization": f"Bearer {user_access_token}"}, + json=correlation_cfg, + ) + assert res.status_code == HTTPStatus.NO_CONTENT, res.json() + + # Then we remove the "fr" zone. + # The deletion should update the correlation matrix of all other zones. + res = client.delete( + f"/v1/studies/{study_id}/areas/fr", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == HTTPStatus.OK, res.json() + + # Check that the "fr" column is removed from the hydraulic correlation matrix. + # The row corresponding to "fr" must also be deleted. + res = client.get( + f"/v1/studies/{study_id}/areas/hydro/correlation/matrix", + headers={"Authorization": f"Bearer {user_access_token}"}, + ) + assert res.status_code == HTTPStatus.OK, res.json() + actual = res.json() + expected = { + "columns": ["de", "es", "it"], + "data": [ + [1.0, 0.12, 0.14], + [0.12, 1.0, 0.23], + [0.14, 0.23, 1.0], + ], + "index": ["de", "es", "it"], + } + assert actual == expected diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index 23fcf9be04..cae6323b49 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -1,12 +1,10 @@ import json -import os import uuid from pathlib import Path from unittest.mock import Mock from zipfile import ZipFile import pytest - from antarest.core.jwt import DEFAULT_ADMIN_USER from antarest.core.requests import RequestParameters from antarest.matrixstore.service import ( @@ -14,36 +12,34 @@ SimpleMatrixService, ) from antarest.study.business.area_management import ( + AreaCreationDTO, AreaManager, AreaType, - AreaCreationDTO, AreaUI, ) -from antarest.study.business.link_management import LinkManager, LinkInfoDTO +from antarest.study.business.link_management import LinkInfoDTO, LinkManager from antarest.study.model import ( - RawStudy, Patch, PatchArea, PatchCluster, + RawStudy, StudyAdditionalData, ) from antarest.study.repository import StudyMetadataRepository from antarest.study.storage.patch_service import PatchService from antarest.study.storage.rawstudy.model.filesystem.config.files import build from antarest.study.storage.rawstudy.model.filesystem.config.model import ( - FileStudyTreeConfig, Area, + Cluster, DistrictSet, + FileStudyTreeConfig, Link, - Cluster, ) from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import ( FileStudyTree, ) -from antarest.study.storage.rawstudy.raw_study_service import ( - RawStudyService, -) +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService from antarest.study.storage.storage_service import StudyStorageService from antarest.study.storage.variantstudy.business.matrix_constants_generator import ( GeneratorMatrixConstants, @@ -57,14 +53,13 @@ from antarest.study.storage.variantstudy.variant_study_service import ( VariantStudyService, ) -from tests.conftest import with_db_context @pytest.fixture -def empty_study(tmpdir: Path) -> FileStudy: +def empty_study(tmp_path: Path) -> FileStudy: cur_dir: Path = Path(__file__).parent - study_path = Path(tmpdir / str(uuid.uuid4())) - os.mkdir(study_path) + study_path = tmp_path.joinpath(str(uuid.uuid4())) + study_path.mkdir() with ZipFile(cur_dir / "assets" / "empty_study_810.zip") as zip_output: zip_output.extractall(path=study_path) config = build(study_path, "1") @@ -72,9 +67,9 @@ def empty_study(tmpdir: Path) -> FileStudy: @pytest.fixture -def matrix_service(tmpdir: Path) -> ISimpleMatrixService: - matrix_path = Path(tmpdir / "matrix_store") - os.mkdir(matrix_path) +def matrix_service(tmp_path: Path) -> ISimpleMatrixService: + matrix_path = tmp_path.joinpath("matrix_store") + matrix_path.mkdir() return SimpleMatrixService(matrix_path) @@ -94,6 +89,7 @@ def test_area_crud( raw_study_service, variant_study_service ) ) + # noinspection PyArgumentList study = RawStudy( id="1", path=empty_study.config.study_path, @@ -143,6 +139,7 @@ def test_area_crud( area_manager.delete_area(study, "test2") assert len(empty_study.config.areas.keys()) == 0 + # noinspection PyArgumentList study = VariantStudy( id="2", path=empty_study.config.study_path, @@ -421,8 +418,6 @@ def test_get_all_area(): {"area1": "a2", "area2": "a3", "ui": None}, ] == [link.dict() for link in links] - pass - def test_update_area(): raw_study_service = Mock(spec=RawStudyService) @@ -523,4 +518,4 @@ def test_update_clusters(): ) assert len(new_area_info.thermals) == 1 assert new_area_info.thermals[0].type == "a" - assert new_area_info.thermals[0].code_oi == None + assert new_area_info.thermals[0].code_oi is None diff --git a/tests/study/business/test_correlation_manager.py b/tests/study/business/test_correlation_manager.py new file mode 100644 index 0000000000..29c07c393a --- /dev/null +++ b/tests/study/business/test_correlation_manager.py @@ -0,0 +1,397 @@ +import contextlib +import datetime +import uuid +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from antarest.core.exceptions import AreaNotFound +from antarest.core.model import PublicMode +from antarest.dbmodel import Base +from antarest.login.model import Group, User +from antarest.study.business.area_management import AreaInfoDTO, AreaType +from antarest.study.business.correlation_management import ( + CorrelationField, + CorrelationFormFields, + CorrelationManager, + CorrelationMatrix, +) +from antarest.study.model import RawStudy, Study, StudyContentStatus +from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy +from antarest.study.storage.rawstudy.model.filesystem.root.filestudytree import ( + FileStudyTree, +) +from antarest.study.storage.rawstudy.raw_study_service import RawStudyService +from antarest.study.storage.storage_service import StudyStorageService +from antarest.study.storage.variantstudy.command_factory import CommandFactory +from antarest.study.storage.variantstudy.model.command.common import ( + CommandName, +) +from antarest.study.storage.variantstudy.model.command.update_config import ( + UpdateConfig, +) +from antarest.study.storage.variantstudy.model.command_context import ( + CommandContext, +) +from antarest.study.storage.variantstudy.variant_study_service import ( + VariantStudyService, +) +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +class TestCorrelationField: + def test_init__nominal_case(self): + field = CorrelationField(area_id="NORTH", coefficient=100) + assert field.area_id == "NORTH" + assert field.coefficient == 100 + + def test_init__camel_case_args(self): + field = CorrelationField(areaId="NORTH", coefficient=100) + assert field.area_id == "NORTH" + assert field.coefficient == 100 + + +class TestCorrelationFormFields: + def test_init__nominal_case(self): + fields = CorrelationFormFields( + correlation=[ + {"area_id": "NORTH", "coefficient": 75}, + {"area_id": "SOUTH", "coefficient": 25}, + ] + ) + assert fields.correlation == [ + CorrelationField(area_id="NORTH", coefficient=75), + CorrelationField(area_id="SOUTH", coefficient=25), + ] + + def test_validation__coefficients_not_empty(self): + """correlation must not be empty""" + with pytest.raises(ValueError, match="must not be empty"): + CorrelationFormFields(correlation=[]) + + def test_validation__coefficients_no_duplicates(self): + """correlation must not contain duplicate area IDs:""" + with pytest.raises(ValueError, match="duplicate area IDs") as ctx: + CorrelationFormFields( + correlation=[ + {"area_id": "NORTH", "coefficient": 50}, + {"area_id": "NORTH", "coefficient": 25}, + {"area_id": "SOUTH", "coefficient": 25}, + ] + ) + assert "NORTH" in str(ctx.value) # duplicates + + @pytest.mark.parametrize("coefficient", [-101, 101, np.nan]) + def test_validation__coefficients_invalid_values(self, coefficient): + """coefficients must be between -100 and 100""" + with pytest.raises( + ValueError, match="between -100 and 100|must not contain NaN" + ): + CorrelationFormFields( + correlation=[ + {"area_id": "NORTH", "coefficient": coefficient}, + ] + ) + + +class TestCorrelationMatrix: + def test_init__nominal_case(self): + field = CorrelationMatrix( + index=["fr", "de"], + columns=["fr"], + data=[ + [1.0], + [0.2], + ], + ) + assert field.index == ["fr", "de"] + assert field.columns == ["fr"] + assert field.data == [ + [1.0], + [0.2], + ] + + def test_validation__coefficients_non_empty_array(self): + """Check that the coefficients matrix is a non-empty array""" + # fmt: off + with pytest.raises(ValueError, match="must not be empty"): + CorrelationMatrix( + index=[], + columns=[], + data=[], + ) + # fmt: off + + def test_validation__coefficients_array_shape(self): + """Check that the coefficients matrix is an array of shape 2×1""" + with pytest.raises(ValueError, match=r"must have shape \(\d+×\d+\)"): + CorrelationMatrix( + index=["fr", "de"], + columns=["fr"], + data=[[1, 2], [3, 4]], + ) + + @pytest.mark.parametrize("coefficient", [-1.1, 1.1, np.nan]) + def test_validation__coefficients_invalid_value(self, coefficient): + """Check that all coefficients matrix has positive or nul coefficients""" + # fmt: off + with pytest.raises(ValueError, match="between -1 and 1|must not contain NaN"): + CorrelationMatrix( + index=["fr", "de"], + columns=["fr", "de"], + data=[ + [1.0, coefficient], + [0.2, 0], + ], + ) + # fmt: on + + def test_validation__matrix_not_symmetric(self): + """Check that the correlation matrix is not symmetric""" + with pytest.raises(ValueError, match=r"not symmetric"): + CorrelationMatrix( + index=["fr", "de"], + columns=["fr", "de"], + data=[[0.1, 0.2], [0.3, 0.4]], + ) + + +@pytest.fixture(scope="function", name="db_engine") +def db_engine_fixture(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + yield engine + engine.dispose() + + +@pytest.fixture(scope="function", name="db_session") +def db_session_fixture(db_engine): + make_session = sessionmaker(bind=db_engine) + with contextlib.closing(make_session()) as session: + yield session + + +# noinspection SpellCheckingInspection +EXECUTE_OR_ADD_COMMANDS = ( + "antarest.study.business.correlation_management.execute_or_add_commands" +) + + +class TestCorrelationManager: + @pytest.fixture(name="study_storage_service") + def study_storage_service(self) -> StudyStorageService: + """Return a mocked StudyStorageService.""" + return Mock( + spec=StudyStorageService, + variant_study_service=Mock( + spec=VariantStudyService, + command_factory=Mock( + spec=CommandFactory, + command_context=Mock(spec=CommandContext), + ), + ), + get_storage=Mock( + return_value=Mock( + spec=RawStudyService, get_raw=Mock(spec=FileStudy) + ) + ), + ) + + # noinspection PyArgumentList + @pytest.fixture(name="study_uuid") + def study_uuid_fixture(self, db_session) -> str: + user = User(id=0, name="admin") + group = Group(id="my-group", name="group") + raw_study = RawStudy( + id=str(uuid.uuid4()), + name="Dummy", + version="850", + author="John Smith", + created_at=datetime.datetime.now(datetime.timezone.utc), + updated_at=datetime.datetime.now(datetime.timezone.utc), + public_mode=PublicMode.FULL, + owner=user, + groups=[group], + workspace="default", + path="/path/to/study", + content_status=StudyContentStatus.WARNING, + ) + db_session.add(raw_study) + db_session.commit() + return raw_study.id + + def test_get_correlation_matrix__nominal_case( + self, db_session, study_storage_service, study_uuid + ): + # The study must be fetched from the database + study: RawStudy = db_session.query(Study).get(study_uuid) + + # Prepare the mocks + correlation_cfg = { + "n%n": 0.1, + "e%e": 0.3, + "s%s": 0.1, + "s%n": 0.2, + "s%w": 0.6, + "w%w": 0.1, + } + storage = study_storage_service.get_storage(study) + file_study = storage.get_raw(study) + file_study.tree = Mock( + spec=FileStudyTree, + get=Mock(return_value=correlation_cfg), + ) + + # Given the following arguments + all_areas = [ + AreaInfoDTO(id="n", name="North", type=AreaType.AREA), + AreaInfoDTO(id="e", name="East", type=AreaType.AREA), + AreaInfoDTO(id="s", name="South", type=AreaType.AREA), + AreaInfoDTO(id="w", name="West", type=AreaType.AREA), + ] + manager = CorrelationManager(study_storage_service) + + # run + matrix = manager.get_correlation_matrix( + all_areas=all_areas, study=study, columns=[] + ) + + # Check + assert matrix == CorrelationMatrix( + index=["n", "e", "s", "w"], + columns=["n", "e", "s", "w"], + data=[ + [1.0, 0.0, 0.2, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.2, 0.0, 1.0, 0.6], + [0.0, 0.0, 0.6, 1.0], + ], + ) + + def test_get_field_values__nominal_case( + self, db_session, study_storage_service, study_uuid + ): + # The study must be fetched from the database + study: RawStudy = db_session.query(Study).get(study_uuid) + + # Prepare the mocks + # NOTE: "s%s" value is ignored + correlation_cfg = {"s%s": 0.1, "n%s": 0.2, "w%n": 0.6} + storage = study_storage_service.get_storage(study) + file_study = storage.get_raw(study) + file_study.tree = Mock( + spec=FileStudyTree, + get=Mock(return_value=correlation_cfg), + ) + + # Given the following arguments + all_areas = [ + AreaInfoDTO(id="n", name="North", type=AreaType.AREA), + AreaInfoDTO(id="e", name="East", type=AreaType.AREA), + AreaInfoDTO(id="s", name="South", type=AreaType.AREA), + AreaInfoDTO(id="w", name="West", type=AreaType.AREA), + ] + area_id = "s" # South + manager = CorrelationManager(study_storage_service) + fields = manager.get_correlation_form_fields( + all_areas=all_areas, study=study, area_id=area_id + ) + assert fields == CorrelationFormFields( + correlation=[ + CorrelationField(area_id="n", coefficient=20.0), + CorrelationField(area_id="s", coefficient=100.0), + ] + ) + + def test_set_field_values__nominal_case( + self, db_session, study_storage_service, study_uuid + ): + # The study must be fetched from the database + study: RawStudy = db_session.query(Study).get(study_uuid) + + # Prepare the mocks: North + South + correlation_cfg = {} + storage = study_storage_service.get_storage(study) + file_study = storage.get_raw(study) + file_study.tree = Mock( + spec=FileStudyTree, + get=Mock(return_value=correlation_cfg), + ) + + # Given the following arguments + all_areas = [ + AreaInfoDTO(id="n", name="North", type=AreaType.AREA), + AreaInfoDTO(id="e", name="East", type=AreaType.AREA), + AreaInfoDTO(id="s", name="South", type=AreaType.AREA), + AreaInfoDTO(id="w", name="West", type=AreaType.AREA), + ] + area_id = "s" # South + manager = CorrelationManager(study_storage_service) + with patch(EXECUTE_OR_ADD_COMMANDS) as exe: + manager.set_correlation_form_fields( + all_areas=all_areas, + study=study, + area_id=area_id, + data=CorrelationFormFields( + correlation=[ + CorrelationField(area_id="s", coefficient=100), + CorrelationField(area_id="e", coefficient=30), + CorrelationField(area_id="n", coefficient=40), + ] + ), + ) + + # check update + assert exe.call_count == 1 + mock_call = exe.mock_calls[0] + # signature: execute_or_add_commands(study, file_study, commands, storage_service) + actual_study, _, actual_cmds, _ = mock_call.args + assert actual_study == study + assert len(actual_cmds) == 1 + cmd: UpdateConfig = actual_cmds[0] + assert cmd.command_name == CommandName.UPDATE_CONFIG + assert cmd.target == "input/hydro/prepro/correlation/annual" + assert cmd.data == {"e%s": 0.3, "n%s": 0.4} + + def test_set_field_values__area_not_found( + self, db_session, study_storage_service, study_uuid + ): + # The study must be fetched from the database + study: RawStudy = db_session.query(Study).get(study_uuid) + + # Prepare the mocks: North + South + correlation_cfg = {} + storage = study_storage_service.get_storage(study) + file_study = storage.get_raw(study) + file_study.tree = Mock( + spec=FileStudyTree, + get=Mock(return_value=correlation_cfg), + ) + + # Given the following arguments + all_areas = [ + AreaInfoDTO(id="n", name="North", type=AreaType.AREA), + AreaInfoDTO(id="e", name="East", type=AreaType.AREA), + AreaInfoDTO(id="s", name="South", type=AreaType.AREA), + AreaInfoDTO(id="w", name="West", type=AreaType.AREA), + ] + area_id = "n" # South + manager = CorrelationManager(study_storage_service) + + with patch(EXECUTE_OR_ADD_COMMANDS) as exe: + with pytest.raises(AreaNotFound) as ctx: + manager.set_correlation_form_fields( + all_areas=all_areas, + study=study, + area_id=area_id, + data=CorrelationFormFields( + correlation=[ + CorrelationField( + area_id="UNKNOWN", coefficient=3.14 + ), + ] + ), + ) + assert "'UNKNOWN'" in ctx.value.detail + exe.assert_not_called() diff --git a/tests/variantstudy/conftest.py b/tests/variantstudy/conftest.py index 77fe89bdf4..e9ad73b7d7 100644 --- a/tests/variantstudy/conftest.py +++ b/tests/variantstudy/conftest.py @@ -51,6 +51,7 @@ def matrix_service() -> MatrixService: @pytest.fixture def command_context(matrix_service: MatrixService) -> CommandContext: + # sourcery skip: inline-immediately-returned-variable command_context = CommandContext( generator_matrix_constants=GeneratorMatrixConstants( matrix_service=matrix_service @@ -75,10 +76,10 @@ def command_factory(matrix_service: MatrixService) -> CommandFactory: @pytest.fixture -def empty_study(tmp_path: str, matrix_service: MatrixService) -> FileStudy: +def empty_study(tmp_path: Path, matrix_service: MatrixService) -> FileStudy: project_dir: Path = Path(__file__).parent.parent.parent empty_study_path: Path = project_dir / "resources" / "empty_study_720.zip" - empty_study_destination_path = Path(tmp_path) / "empty-study" + empty_study_destination_path = tmp_path.joinpath("empty-study") with zipfile.ZipFile(empty_study_path, "r") as zip_empty_study: zip_empty_study.extractall(empty_study_destination_path) @@ -90,6 +91,7 @@ def empty_study(tmp_path: str, matrix_service: MatrixService) -> FileStudy: areas={}, sets={}, ) + # sourcery skip: inline-immediately-returned-variable file_study = FileStudy( config=config, tree=FileStudyTree( diff --git a/tests/variantstudy/model/command/test_remove_area.py b/tests/variantstudy/model/command/test_remove_area.py index 6fd8e646fb..b1f61f0bd9 100644 --- a/tests/variantstudy/model/command/test_remove_area.py +++ b/tests/variantstudy/model/command/test_remove_area.py @@ -1,10 +1,10 @@ -from checksumdir import dirhash +import pytest -from antarest.study.storage.rawstudy.io.reader import IniReader from antarest.study.storage.rawstudy.model.filesystem.config.model import ( transform_name_to_id, ) from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy +from antarest.study.storage.study_upgrader import upgrade_study from antarest.study.storage.variantstudy.model.command.common import ( TimeStep, BindingConstraintOperator, @@ -41,25 +41,14 @@ class TestRemoveArea: - def test_validation(self, empty_study: FileStudy): - pass - + @pytest.mark.parametrize("version", [810, 840]) def test_apply( self, empty_study: FileStudy, command_context: CommandContext, + version: int, ): - bd_config = IniReader().read( - empty_study.config.study_path - / "input" - / "bindingconstraints" - / "bindingconstraints.ini" - ) - - area_name = "Area" - area_id = transform_name_to_id(area_name) - area_name2 = "Area2" - area_id2 = transform_name_to_id(area_name2) + # noinspection SpellCheckingInspection empty_study.tree.save( { "input": { @@ -84,6 +73,8 @@ def test_apply( } ) + area_name = "Area" + area_id = transform_name_to_id(area_name) create_area_command: ICommand = CreateArea.parse_obj( { "area_name": area_name, @@ -93,14 +84,6 @@ def test_apply( output = create_area_command.apply(study_data=empty_study) assert output.status - parameters = { - "group": "Other", - "unitcount": "1", - "nominalcapacity": "1000000", - "marginal-cost": "30", - "market-bid-cost": "30", - } - create_district_command = CreateDistrict( name="foo", base_filter=DistrictBaseFilter.add_all, @@ -112,85 +95,99 @@ def test_apply( ######################################################################################## - empty_study_hash = dirhash(empty_study.config.study_path, "md5") + upgrade_study(empty_study.config.study_path, str(version)) - for version in [810, 840]: - empty_study.config.version = version - create_area_command: ICommand = CreateArea.parse_obj( - { - "area_name": area_name2, - "command_context": command_context, - } - ) - output = create_area_command.apply(study_data=empty_study) - assert output.status - - create_link_command: ICommand = CreateLink( - area1=area_id, - area2=area_id2, - parameters={}, - command_context=command_context, - series=[[0]], - ) - output = create_link_command.apply(study_data=empty_study) - assert output.status - - create_cluster_command = CreateCluster.parse_obj( - { - "area_id": area_id2, - "cluster_name": "cluster", - "parameters": parameters, - "prepro": [[0]], - "modulation": [[0]], - "command_context": command_context, - } - ) - output = create_cluster_command.apply(study_data=empty_study) - assert output.status - - bind1_cmd = CreateBindingConstraint( - name="BD 2", - time_step=TimeStep.HOURLY, - operator=BindingConstraintOperator.LESS, - coeffs={ - f"{area_id}%{area_id2}": [400, 30], - f"{area_id2}.cluster": [400, 30], + empty_study_cfg = empty_study.tree.get(depth=999) + if version >= 830: + empty_study_cfg["input"]["areas"][area_id]["adequacy_patch"] = { + "adequacy-patch": {"adequacy-patch-mode": "outside"} + } + empty_study_cfg["input"]["links"][area_id]["capacities"] = {} + + area_name2 = "Area2" + area_id2 = transform_name_to_id(area_name2) + + empty_study.config.version = version + create_area_command: ICommand = CreateArea.parse_obj( + { + "area_name": area_name2, + "command_context": command_context, + } + ) + output = create_area_command.apply(study_data=empty_study) + assert output.status + + create_link_command: ICommand = CreateLink( + area1=area_id, + area2=area_id2, + parameters={}, + command_context=command_context, + series=[[0]], + ) + output = create_link_command.apply(study_data=empty_study) + assert output.status + + # noinspection SpellCheckingInspection + create_cluster_command = CreateCluster.parse_obj( + { + "area_id": area_id2, + "cluster_name": "cluster", + "parameters": { + "group": "Other", + "unitcount": "1", + "nominalcapacity": "1000000", + "marginal-cost": "30", + "market-bid-cost": "30", }, - comments="Hello", - command_context=command_context, - ) - output = bind1_cmd.apply(study_data=empty_study) - assert output.status - - remove_district_command = RemoveDistrict( - id="foo", - command_context=command_context, - ) - output = remove_district_command.apply(study_data=empty_study) - assert output.status - - create_district_command = CreateDistrict( - name="foo", - base_filter=DistrictBaseFilter.add_all, - filter_items=[area_id, area_id2], - command_context=command_context, - ) - output = create_district_command.apply(study_data=empty_study) - assert output.status - - remove_area_command: ICommand = RemoveArea.parse_obj( - { - "id": transform_name_to_id(area_name2), - "command_context": command_context, - } - ) - output = remove_area_command.apply(study_data=empty_study) - assert output.status - - assert ( - dirhash(empty_study.config.study_path, "md5") - == empty_study_hash - ) + "prepro": [[0]], + "modulation": [[0]], + "command_context": command_context, + } + ) + output = create_cluster_command.apply(study_data=empty_study) + assert output.status + + bind1_cmd = CreateBindingConstraint( + name="BD 2", + time_step=TimeStep.HOURLY, + operator=BindingConstraintOperator.LESS, + coeffs={ + f"{area_id}%{area_id2}": [400, 30], + f"{area_id2}.cluster": [400, 30], + }, + comments="Hello", + command_context=command_context, + ) + output = bind1_cmd.apply(study_data=empty_study) + assert output.status + + remove_district_command = RemoveDistrict( + id="foo", + command_context=command_context, + ) + output = remove_district_command.apply(study_data=empty_study) + assert output.status + + create_district_command = CreateDistrict( + name="foo", + base_filter=DistrictBaseFilter.add_all, + filter_items=[area_id, area_id2], + command_context=command_context, + ) + output = create_district_command.apply(study_data=empty_study) + assert output.status + + remove_area_command: ICommand = RemoveArea.parse_obj( + { + "id": transform_name_to_id(area_name2), + "command_context": command_context, + } + ) + output = remove_area_command.apply(study_data=empty_study) + assert output.status + + actual_cfg = empty_study.tree.get(depth=999) + assert actual_cfg == empty_study_cfg def test_match(command_context: CommandContext):