Skip to content

Commit

Permalink
merge with main
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Feb 18, 2025
2 parents 7eeef2d + 0bffadb commit 3b44a4c
Show file tree
Hide file tree
Showing 186 changed files with 4,486 additions and 3,569 deletions.
15 changes: 15 additions & 0 deletions antarest/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def from_dict(cls, data: JSON) -> "StorageConfig":
if "workspaces" in data
else defaults.workspaces
)

cls._validate_workspaces(data, workspaces)
return cls(
matrixstore=Path(data["matrixstore"]) if "matrixstore" in data else defaults.matrixstore,
archive_dir=Path(data["archive_dir"]) if "archive_dir" in data else defaults.archive_dir,
Expand All @@ -225,6 +227,19 @@ def from_dict(cls, data: JSON) -> "StorageConfig":
matrixstore_format=InternalMatrixFormat(data.get("matrixstore_format", defaults.matrixstore_format)),
)

@classmethod
def _validate_workspaces(cls, config_as_json: JSON, workspaces: Dict[str, WorkspaceConfig]) -> None:
"""
Validate that no two workspaces have overlapping paths.
"""
workspace_name_by_path = [(config.path, name) for name, config in workspaces.items()]
for path, name in workspace_name_by_path:
for path2, name2 in workspace_name_by_path:
if name != name2 and path.is_relative_to(path2):
raise ValueError(
f"Overlapping workspace paths found: '{name}' and '{name2}' '{path}' is relative to '{path2}' "
)


@dataclass(frozen=True)
class NbCoresConfig:
Expand Down
4 changes: 4 additions & 0 deletions antarest/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import typing_extensions as te
from pydantic import StringConstraints

from antarest.core.serde import AntaresBaseModel

if TYPE_CHECKING:
Expand All @@ -22,6 +25,7 @@
JSON = Dict[str, Any]
ELEMENT = Union[str, int, float, bool, bytes]
SUB_JSON = Union[ELEMENT, JSON, List[Any], None]
LowerCaseStr = te.Annotated[str, StringConstraints(to_lower=True)]


class PublicMode(enum.StrEnum):
Expand Down
4 changes: 3 additions & 1 deletion antarest/eventbus/business/redis_eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# This file is part of the Antares project.

import logging
import pathlib
from typing import List, Optional, cast

from redis.client import Redis
Expand Down Expand Up @@ -42,7 +43,8 @@ def queue_event(self, event: Event, queue: str) -> None:
def pull_queue(self, queue: str) -> Optional[Event]:
event = self.redis.lpop(queue)
if event:
return cast(Optional[Event], Event.parse_raw(event))
event_string = pathlib.Path(event).read_text()
return cast(Optional[Event], Event.model_validate_json(event_string))
return None

@override
Expand Down
3 changes: 2 additions & 1 deletion antarest/eventbus/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import dataclasses
import logging
import pathlib
from enum import StrEnum
from http import HTTPStatus
from typing import List, Optional
Expand Down Expand Up @@ -75,7 +76,7 @@ def process_message(self, message: str, websocket: WebSocket) -> None:
if not connection:
return

ws_message = WebsocketMessage.parse_raw(message)
ws_message = WebsocketMessage.model_validate_json(message)
if ws_message.action == WebsocketMessageAction.SUBSCRIBE:
if ws_message.payload not in connection.channel_subscriptions:
connection.channel_subscriptions.append(ws_message.payload)
Expand Down
2 changes: 1 addition & 1 deletion antarest/launcher/extensions/adequacy_patch/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from antarest.core.utils.utils import assert_this
from antarest.launcher.extensions.interface import ILauncherExtension
from antarest.study.service import StudyService
from antarest.study.storage.rawstudy.model.filesystem.config.model import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.config.identifier import transform_name_to_id
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy

logger = logging.getLogger(__name__)
Expand Down
65 changes: 65 additions & 0 deletions antarest/matrixstore/in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2025, RTE (https://www.rte-france.com)
#
# See AUTHORS.txt
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#
# SPDX-License-Identifier: MPL-2.0
#
# This file is part of the Antares project.

import hashlib
import time
from typing import Dict, List, Optional

import numpy as np
import numpy.typing as npt
from typing_extensions import override

from antarest.matrixstore.model import MatrixData, MatrixDTO
from antarest.matrixstore.service import ISimpleMatrixService


class InMemorySimpleMatrixService(ISimpleMatrixService):
"""
In memory implementation of matrix service, for unit testing purposes.
"""

def __init__(self) -> None:
self._content: Dict[str, MatrixDTO] = {}

def _make_dto(self, id: str, matrix: npt.NDArray[np.float64]) -> MatrixDTO:
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
data = matrix.tolist()
index = [str(i) for i in range(matrix.shape[0])]
columns = [str(i) for i in range(matrix.shape[1])]
return MatrixDTO(
data=data,
index=index,
columns=columns,
id=id,
created_at=int(time.time()),
width=len(columns),
height=len(index),
)

@override
def create(self, data: List[List[MatrixData]] | npt.NDArray[np.float64]) -> str:
matrix = data if isinstance(data, np.ndarray) else np.array(data, dtype=np.float64)
matrix_hash = hashlib.sha256(matrix.data).hexdigest()
self._content[matrix_hash] = self._make_dto(matrix_hash, matrix)
return matrix_hash

@override
def get(self, matrix_id: str) -> Optional[MatrixDTO]:
return self._content.get(matrix_id, None)

@override
def exists(self, matrix_id: str) -> bool:
return matrix_id in self._content

@override
def delete(self, matrix_id: str) -> None:
del self._content[matrix_id]
4 changes: 2 additions & 2 deletions antarest/matrixstore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def __eq__(self, other: Any) -> bool:
class MatrixDTO(AntaresBaseModel):
width: int
height: int
index: List[str]
columns: List[str]
index: List[int | str]
columns: List[int | str]
data: List[List[MatrixData]]
created_at: int = 0
id: str = ""
Expand Down
6 changes: 3 additions & 3 deletions antarest/matrixstore/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def get(self, matrix_hash: str) -> MatrixContent:
matrix = storage_format.load_matrix(matrix_path)
matrix = matrix.reshape((1, 0)) if matrix.size == 0 else matrix
data = matrix.tolist()
index = list(range(matrix.shape[0]))
columns = list(range(matrix.shape[1]))
return MatrixContent.construct(data=data, columns=columns, index=index)
index: List[int | str] = list(range(matrix.shape[0]))
columns: List[int | str] = list(range(matrix.shape[1]))
return MatrixContent.model_construct(data=data, columns=columns, index=index)

def exists(self, matrix_hash: str) -> bool:
"""
Expand Down
4 changes: 2 additions & 2 deletions antarest/matrixstore/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create(self, data: List[List[MatrixData]] | npt.NDArray[np.float64]) -> str:
@override
def get(self, matrix_id: str) -> MatrixDTO:
data = self.matrix_content_repository.get(matrix_id)
return MatrixDTO.construct(
return MatrixDTO.model_construct(
id=matrix_id,
width=len(data.columns),
height=len(data.index),
Expand Down Expand Up @@ -394,7 +394,7 @@ def get(self, matrix_id: str) -> Optional[MatrixDTO]:
if matrix is None:
return None
content = self.matrix_content_repository.get(matrix_id)
return MatrixDTO.construct(
return MatrixDTO.model_construct(
id=matrix.id,
width=matrix.width,
height=matrix.height,
Expand Down
6 changes: 3 additions & 3 deletions antarest/matrixstore/uri_resolver_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@

import pandas as pd

from antarest.core.model import SUB_JSON
from antarest.core.model import JSON
from antarest.matrixstore.service import ISimpleMatrixService


class UriResolverService:
def __init__(self, matrix_service: ISimpleMatrixService):
self.matrix_service = matrix_service

def resolve(self, uri: str, formatted: bool = True) -> SUB_JSON:
def resolve(self, uri: str, formatted: bool = True) -> JSON | str | None:
res = UriResolverService._extract_uri_components(uri)
if res:
protocol, uuid = res
Expand All @@ -49,7 +49,7 @@ def extract_id(uri: str) -> Optional[str]:
res = UriResolverService._extract_uri_components(uri)
return res[1] if res else None

def _resolve_matrix(self, id: str, formatted: bool = True) -> SUB_JSON:
def _resolve_matrix(self, id: str, formatted: bool = True) -> JSON | str:
data = self.matrix_service.get(id)
if not data:
raise ValueError(f"id matrix {id} not found")
Expand Down
24 changes: 12 additions & 12 deletions antarest/study/business/adequacy_patch_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from antarest.study.business.all_optional_meta import all_optional_model
from antarest.study.business.enum_ignore_case import EnumIgnoreCase
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands
from antarest.study.model import STUDY_VERSION_8_3, STUDY_VERSION_8_5, Study
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.business.study_interface import StudyInterface
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel
from antarest.study.model import STUDY_VERSION_8_3, STUDY_VERSION_8_5
from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig
from antarest.study.storage.variantstudy.model.command_context import CommandContext


class PriceTakingOrder(EnumIgnoreCase):
Expand Down Expand Up @@ -98,28 +99,28 @@ class AdequacyPatchFormFields(FormFieldsBaseModel):


class AdequacyPatchManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service
def __init__(self, command_context: CommandContext) -> None:
self._command_context = command_context

def get_field_values(self, study: Study) -> AdequacyPatchFormFields:
def get_field_values(self, study: StudyInterface) -> AdequacyPatchFormFields:
"""
Get adequacy patch field values for the webapp form
"""
file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
general_data = file_study.tree.get(GENERAL_DATA_PATH.split("/"))
parent = general_data.get("adequacy patch", {})

def get_value(field_info: FieldInfo) -> Any:
path = field_info["path"]
start_version = field_info.get("start_version", -1)
target_name = path.split("/")[-1]
is_in_version = file_study.config.version >= start_version
is_in_version = study.version >= start_version

return parent.get(target_name, field_info["default_value"]) if is_in_version else None

return AdequacyPatchFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()})

def set_field_values(self, study: Study, field_values: AdequacyPatchFormFields) -> None:
def set_field_values(self, study: StudyInterface, field_values: AdequacyPatchFormFields) -> None:
"""
Set adequacy patch config from the webapp form
"""
Expand All @@ -133,11 +134,10 @@ def set_field_values(self, study: Study, field_values: AdequacyPatchFormFields)
UpdateConfig(
target=info["path"],
data=value,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
command_context=self._command_context,
study_version=study.version,
)
)

if commands:
file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, commands, self.storage_service)
study.add_commands(commands)
25 changes: 12 additions & 13 deletions antarest/study/business/advanced_parameters_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@

from typing import Any, Dict, List

from antares.study.version import StudyVersion
from pydantic import field_validator
from pydantic.types import StrictInt, StrictStr

from antarest.core.exceptions import InvalidFieldForVersionError
from antarest.study.business.all_optional_meta import all_optional_model
from antarest.study.business.enum_ignore_case import EnumIgnoreCase
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel, execute_or_add_commands
from antarest.study.model import STUDY_VERSION_8_8, Study
from antarest.study.storage.storage_service import StudyStorageService
from antarest.study.business.study_interface import StudyInterface
from antarest.study.business.utils import GENERAL_DATA_PATH, FieldInfo, FormFieldsBaseModel
from antarest.study.model import STUDY_VERSION_8_8
from antarest.study.storage.variantstudy.model.command.update_config import UpdateConfig
from antarest.study.storage.variantstudy.model.command_context import CommandContext


class InitialReservoirLevel(EnumIgnoreCase):
Expand Down Expand Up @@ -216,14 +216,14 @@ def check_accuracy_on_correlation(cls, v: str) -> str:


class AdvancedParamsManager:
def __init__(self, storage_service: StudyStorageService) -> None:
self.storage_service = storage_service
def __init__(self, command_context: CommandContext) -> None:
self._command_context = command_context

def get_field_values(self, study: Study) -> AdvancedParamsFormFields:
def get_field_values(self, study: StudyInterface) -> AdvancedParamsFormFields:
"""
Get Advanced parameters values for the webapp form
"""
file_study = self.storage_service.get_storage(study).get_raw(study)
file_study = study.get_files()
general_data = file_study.tree.get(GENERAL_DATA_PATH.split("/"))
advanced_params = general_data.get("advanced parameters", {})
other_preferences = general_data.get("other preferences", {})
Expand All @@ -242,7 +242,7 @@ def get_value(field_info: FieldInfo) -> Any:

return AdvancedParamsFormFields.model_construct(**{name: get_value(info) for name, info in FIELDS_INFO.items()})

def set_field_values(self, study: Study, field_values: AdvancedParamsFormFields) -> None:
def set_field_values(self, study: StudyInterface, field_values: AdvancedParamsFormFields) -> None:
"""
Set Advanced parameters values from the webapp form
"""
Expand All @@ -256,19 +256,18 @@ def set_field_values(self, study: Study, field_values: AdvancedParamsFormFields)
if (
field_name == "unit_commitment_mode"
and value == UnitCommitmentMode.MILP
and StudyVersion.parse(study.version) < STUDY_VERSION_8_8
and study.version < STUDY_VERSION_8_8
):
raise InvalidFieldForVersionError("Unit commitment mode `MILP` only exists in v8.8+ studies")

commands.append(
UpdateConfig(
target=info["path"],
data=value,
command_context=self.storage_service.variant_study_service.command_factory.command_context,
command_context=self._command_context,
study_version=study.version,
)
)

if len(commands) > 0:
file_study = self.storage_service.get_storage(study).get_raw(study)
execute_or_add_commands(study, file_study, commands, self.storage_service)
study.add_commands(commands)
Loading

0 comments on commit 3b44a4c

Please sign in to comment.