Skip to content

Commit

Permalink
refactor: introduce _MATRIX_NAMES to enumerate all matrix names
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Jul 12, 2023
1 parent 77bb8ff commit 2c4fab7
Showing 1 changed file with 32 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,16 @@
from pydantic import Field, validator, Extra
from pydantic.fields import ModelField

# minimum required version.
# noinspection SpellCheckingInspection
_MATRIX_NAMES = (
"pmax_injection",
"pmax_withdrawal",
"lower_rule_curve",
"upper_rule_curve",
"inflows",
)

# Minimum required version.
REQUIRED_VERSION = 860

MatrixType = List[List[MatrixData]]
Expand Down Expand Up @@ -88,14 +97,7 @@ def storage_name(self) -> str:
"""The label representing the name of the storage for the user."""
return self.parameters.name

@validator(
"pmax_injection",
"pmax_withdrawal",
"lower_rule_curve",
"upper_rule_curve",
"inflows",
always=True,
)
@validator(*_MATRIX_NAMES, always=True)
def register_matrix(
cls,
v: Optional[Union[MatrixType, str]],
Expand Down Expand Up @@ -134,27 +136,27 @@ def register_matrix(
method = getattr(constants, method_name)
return cast(str, method())
if isinstance(v, str):
# check the matrix link
# Check the matrix link
return validate_matrix(v, values)
if isinstance(v, list):
# check the matrix values and create the corresponding matrix link
# Check the matrix values and create the corresponding matrix link
array = np.array(v, dtype=np.float64)
if array.shape != (8760, 1):
raise ValueError(
f"Invalid matrix shape {array.shape}, expected (8760, 1)"
)
if np.isnan(array).any():
raise ValueError("Matrix values cannot contain NaN")
if field.name in {
"pmax_injection",
"pmax_withdrawal",
"lower_rule_curve",
"upper_rule_curve",
} and (np.any(array < 0) or np.any(array > 1)):
# All matrices except "inflows" are constrained between 0 and 1
constrained = set(_MATRIX_NAMES) - {"inflows"}
if field.name in constrained and (
np.any(array < 0) or np.any(array > 1)
):
raise ValueError("Matrix values should be between 0 and 1")
v = cast(MatrixType, array.tolist())
return validate_matrix(v, values)
# invalid datatype (not implemented?)
# Invalid datatype
# pragma: no cover
raise TypeError(repr(v))

def _apply_config(
Expand Down Expand Up @@ -254,11 +256,8 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
"series": {
self.area_id: {
self.storage_id: {
"pmax_injection": self.pmax_injection,
"pmax_withdrawal": self.pmax_withdrawal,
"lower_rule_curve": self.lower_rule_curve,
"upper_rule_curve": self.upper_rule_curve,
"inflows": self.inflows,
attr: getattr(self, attr)
for attr in _MATRIX_NAMES
}
}
},
Expand All @@ -277,21 +276,18 @@ def to_dto(self) -> CommandDTO:
Returns:
The DTO object representing the current command.
"""
# fmt: off
parameters = json.loads(self.parameters.json(by_alias=True))
return CommandDTO(
action=self.command_name.value,
args={
"area_id": self.area_id,
"parameters": parameters,
"pmax_injection": strip_matrix_protocol(self.pmax_injection),
"pmax_withdrawal": strip_matrix_protocol(self.pmax_withdrawal),
"lower_rule_curve": strip_matrix_protocol(self.lower_rule_curve),
"upper_rule_curve": strip_matrix_protocol(self.upper_rule_curve),
"inflows": strip_matrix_protocol(self.inflows),
**{
attr: strip_matrix_protocol(getattr(self, attr))
for attr in _MATRIX_NAMES
},
},
)
# fmt: on

def match_signature(self) -> str:
"""Returns the command signature."""
Expand All @@ -317,7 +313,7 @@ def match(self, other: "ICommand", equal: bool = False) -> bool:
if not isinstance(other, CreateSTStorage):
return False
if equal:
# deep comparison
# Deep comparison
return self.__eq__(other)
else:
return (
Expand All @@ -337,33 +333,24 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]:
A list of commands representing the differences between
the two `ICommand` objects.
"""
other = cast(CreateSTStorage, other)
from antarest.study.storage.variantstudy.model.command.replace_matrix import (
ReplaceMatrix,
)
from antarest.study.storage.variantstudy.model.command.update_config import (
UpdateConfig,
)

# fixme: drop this mapping
attrs = {
"pmax_injection": "pmax_injection",
"pmax_withdrawal": "pmax_withdrawal",
"lower_rule_curve": "lower_rule_curve",
"upper_rule_curve": "upper_rule_curve",
"inflows": "inflows",
}
other = cast(CreateSTStorage, other)
commands: List[ICommand] = [
ReplaceMatrix(
target=f"input/st-storage/series/{self.area_id}/{self.storage_id}/{ini_name}",
target=f"input/st-storage/series/{self.area_id}/{self.storage_id}/{attr}",
matrix=strip_matrix_protocol(getattr(other, attr)),
command_context=self.command_context,
)
for ini_name, attr in attrs.items()
for attr in _MATRIX_NAMES
if getattr(self, attr) != getattr(other, attr)
]
if self.parameters != other.parameters:
# Exclude the `id` because it is read-only, and they can't be modified (calculated)
data: Dict[str, Any] = json.loads(
other.parameters.json(by_alias=True)
)
Expand All @@ -380,14 +367,8 @@ def get_inner_matrices(self) -> List[str]:
"""
Retrieves the list of matrix IDs.
"""
attrs = [
"pmax_injection",
"pmax_withdrawal",
"lower_rule_curve",
"upper_rule_curve",
"inflows",
]
matrices: List[str] = [
strip_matrix_protocol(getattr(self, attr)) for attr in attrs
strip_matrix_protocol(getattr(self, attr))
for attr in _MATRIX_NAMES
]
return matrices

0 comments on commit 2c4fab7

Please sign in to comment.