Skip to content

Commit

Permalink
refactor(services): fix wrong dependency from business to service (#2343
Browse files Browse the repository at this point in the history
)

Signed-off-by: Sylvain Leclerc <[email protected]>
  • Loading branch information
sylvlecl authored Feb 18, 2025
1 parent 6c8b795 commit 0bffadb
Show file tree
Hide file tree
Showing 52 changed files with 2,033 additions and 2,211 deletions.
65 changes: 65 additions & 0 deletions antarest/matrixstore/in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2025, RTE (https://www.rte-france.com)
#
# See AUTHORS.txt
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.

import hashlib
import time
from typing import Dict, List, Optional

import numpy as np
import numpy.typing as npt
from typing_extensions import override

from antarest.matrixstore.model import MatrixData, MatrixDTO
from antarest.matrixstore.service import ISimpleMatrixService


class InMemorySimpleMatrixService(ISimpleMatrixService):
"""
In memory implementation of matrix service, for unit testing purposes.
"""

def __init__(self) -> None:
self._content: Dict[str, MatrixDTO] = {}

def _make_dto(self, id: str, matrix: npt.NDArray[np.float64]) -> MatrixDTO:
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
data = matrix.tolist()
index = [str(i) for i in range(matrix.shape[0])]
columns = [str(i) for i in range(matrix.shape[1])]
return MatrixDTO(
data=data,
index=index,
columns=columns,
id=id,
created_at=int(time.time()),
width=len(columns),
height=len(index),
)

@override
def create(self, data: List[List[MatrixData]] | npt.NDArray[np.float64]) -> str:
matrix = data if isinstance(data, np.ndarray) else np.array(data, dtype=np.float64)
matrix_hash = hashlib.sha256(matrix.data).hexdigest()
self._content[matrix_hash] = self._make_dto(matrix_hash, matrix)
return matrix_hash

@override
def get(self, matrix_id: str) -> Optional[MatrixDTO]:
return self._content.get(matrix_id, None)

@override
def exists(self, matrix_id: str) -> bool:
return matrix_id in self._content

@override
def delete(self, matrix_id: str) -> None:
del self._content[matrix_id]
24 changes: 12 additions & 12 deletions antarest/study/business/adequacy_patch_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from antarest.study.business.all_optional_meta import all_optional_model
from antarest.study.business.enum_ignore_case import EnumIgnoreCase
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands
from antarest.study.model import STUDY_VERSION_8_3, STUDY_VERSION_8_5, Study
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.business.study_interface import StudyInterface
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel
from antarest.study.model import STUDY_VERSION_8_3, STUDY_VERSION_8_5
from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig
from antarest.study.storage.variantstudy.model.command_context import CommandContext


class PriceTakingOrder(EnumIgnoreCase):
Expand Down Expand Up @@ -98,28 +99,28 @@ class AdequacyPatchFormFields(FormFieldsBaseModel):


class AdequacyPatchManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service
def __init__(self, command_context: CommandContext) -> None:
self._command_context = command_context

def get_field_values(self, study: Study) -> AdequacyPatchFormFields:
def get_field_values(self, study: StudyInterface) -> AdequacyPatchFormFields:
"""
Get adequacy patch field values for the webapp form
"""
file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
general_data = file_study.tree.get(GENERAL_DATA_PATH.split("/"))
parent = general_data.get("adequacy patch", {})

def get_value(field_info: FieldInfo) -> Any:
path = field_info["path"]
start_version = field_info.get("start_version", -1)
target_name = path.split("/")[-1]
is_in_version = file_study.config.version >= start_version
is_in_version = study.version >= start_version

return parent.get(target_name, field_info["default_value"]) if is_in_version else None

return AdequacyPatchFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()})

def set_field_values(self, study: Study, field_values: AdequacyPatchFormFields) -> None:
def set_field_values(self, study: StudyInterface, field_values: AdequacyPatchFormFields) -> None:
"""
Set adequacy patch config from the webapp form
"""
Expand All @@ -133,11 +134,10 @@ def set_field_values(self, study: Study, field_values: AdequacyPatchFormFields)
UpdateConfig(
target=info["path"],
data=value,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
command_context=self._command_context,
study_version=study.version,
)
)

if commands:
file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, commands, self.storage_service)
study.add_commands(commands)
25 changes: 12 additions & 13 deletions antarest/study/business/advanced_parameters_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@

from typing import Any, Dict, List

from antares.study.version import StudyVersion
from pydantic import field_validator
from pydantic.types import StrictInt, StrictStr

from antarest.core.exceptions import InvalidFieldForVersionError
from antarest.study.business.all_optional_meta import all_optional_model
from antarest.study.business.enum_ignore_case import EnumIgnoreCase
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands
from antarest.study.model import STUDY_VERSION_8_8, Study
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.business.study_interface import StudyInterface
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel
from antarest.study.model import STUDY_VERSION_8_8
from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig
from antarest.study.storage.variantstudy.model.command_context import CommandContext


class InitialReservoirLevel(EnumIgnoreCase):
Expand Down Expand Up @@ -216,14 +216,14 @@ def check_accuracy_on_correlation(cls, v: str) -> str:


class AdvancedParamsManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service
def __init__(self, command_context: CommandContext) -> None:
self._command_context = command_context

def get_field_values(self, study: Study) -> AdvancedParamsFormFields:
def get_field_values(self, study: StudyInterface) -> AdvancedParamsFormFields:
"""
Get Advanced parameters values for the webapp form
"""
file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
general_data = file_study.tree.get(GENERAL_DATA_PATH.split("/"))
advanced_params = general_data.get("advanced parameters", {})
other_preferences = general_data.get("other preferences", {})
Expand All @@ -242,7 +242,7 @@ def get_value(field_info: FieldInfo) -> Any:

return AdvancedParamsFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()})

def set_field_values(self, study: Study, field_values: AdvancedParamsFormFields) -> None:
def set_field_values(self, study: StudyInterface, field_values: AdvancedParamsFormFields) -> None:
"""
Set Advanced parameters values from the webapp form
"""
Expand All @@ -256,19 +256,18 @@ def set_field_values(self, study: Study, field_values: AdvancedParamsFormFields)
if (
field_name == "unit_commitment_mode"
and value == UnitCommitmentMode.MILP
and StudyVersion.parse(study.version) < STUDY_VERSION_8_8
and study.version < STUDY_VERSION_8_8
):
raise InvalidFieldForVersionError("Unit commitment mode `MILP` only exists in v8.8+ studies")

commands.append(
UpdateConfig(
target=info["path"],
data=value,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
command_context=self._command_context,
study_version=study.version,
)
)

if len(commands) > 0:
file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, commands, self.storage_service)
study.add_commands(commands)
30 changes: 14 additions & 16 deletions antarest/study/business/allocation_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

from antarest.core.exceptions import AllocationDataNotFound, AreaNotFound
from antarest.study.business.model.area_model import AreaInfoDTO
from antarest.study.business.utils import FormFieldsBaseModel, execute_or_add_commands
from antarest.study.model import Study
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.business.study_interface import StudyInterface
from antarest.study.business.utils import FormFieldsBaseModel
from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig
from antarest.study.storage.variantstudy.model.command_context import CommandContext


class AllocationField(FormFieldsBaseModel):
Expand Down Expand Up @@ -114,10 +114,10 @@ class AllocationManager:
Manage hydraulic allocation coefficients.
"""

def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service
def __init__(self, command_context: CommandContext) -> None:
self._command_context = command_context

def get_allocation_data(self, study: Study, area_id: str) -> Dict[str, float]:
def get_allocation_data(self, study: StudyInterface, area_id: str) -> Dict[str, float]:
"""
Get hydraulic allocation data.
Expand All @@ -133,7 +133,7 @@ def get_allocation_data(self, study: Study, area_id: str) -> Dict[str, float]:
"""
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression

file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
allocation_data = file_study.tree.get(f"input/hydro/allocation/{area_id}".split("/"), depth=2)

if not allocation_data:
Expand All @@ -142,7 +142,7 @@ def get_allocation_data(self, study: Study, area_id: str) -> Dict[str, float]:
return allocation_data.get("[allocation]", {}) # type: ignore

def get_allocation_form_fields(
self, all_areas: List[AreaInfoDTO], study: Study, area_id: str
self, all_areas: List[AreaInfoDTO], study: StudyInterface, area_id: str
) -> AllocationFormFields:
"""
Get hydraulic allocation coefficients.
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_allocation_form_fields(
def set_allocation_form_fields(
self,
all_areas: List[AreaInfoDTO],
study: Study,
study: StudyInterface,
area_id: str,
data: AllocationFormFields,
) -> AllocationFormFields:
Expand All @@ -198,16 +198,14 @@ def set_allocation_form_fields(

filtered_allocations = [f for f in data.allocation if f.coefficient > 0 and f.area_id in areas_ids]

command_context = self.storage_service.variant_study_service.command_factory.command_context
file_study = self.storage_service.get_storage(study).get_raw(study)
command = UpdateConfig(
target=f"input/hydro/allocation/{area_id}/[allocation]",
data={f.area_id: f.coefficient for f in filtered_allocations},
command_context=command_context,
study_version=file_study.config.version,
command_context=self._command_context,
study_version=study.version,
)

execute_or_add_commands(study, file_study, [command], self.storage_service)
study.add_commands([command])

updated_allocations = self.get_allocation_data(study, area_id)

Expand All @@ -218,7 +216,7 @@ def set_allocation_form_fields(
]
)

def get_allocation_matrix(self, study: Study, all_areas: List[AreaInfoDTO]) -> AllocationMatrix:
def get_allocation_matrix(self, study: StudyInterface, all_areas: List[AreaInfoDTO]) -> AllocationMatrix:
"""
Get the hydraulic allocation matrix for all areas in the study.
Expand All @@ -233,7 +231,7 @@ def get_allocation_matrix(self, study: Study, all_areas: List[AreaInfoDTO]) -> A
AllocationDataNotFound: if the allocation data is not found.
"""

file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
allocation_cfg = file_study.tree.get(["input", "hydro", "allocation"], depth=3)

if not allocation_cfg:
Expand Down
Loading

0 comments on commit 0bffadb

Please sign in to comment.