Skip to content

Commit

Permalink
chore(binding-constraint): reduce code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
laurent-laporte-pro committed Oct 11, 2023
1 parent d277805 commit b41d957
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 93 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import numpy as np
Expand All @@ -21,19 +22,43 @@
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO

__all__ = ("AbstractBindingConstraintCommand", "CreateBindingConstraint", "check_matrix_values")

MatrixType = List[List[MatrixData]]


class CreateBindingConstraint(ICommand):
"""
Command used to create a binding constraint.
def check_matrix_values(time_step: BindingConstraintFrequency, values: MatrixType) -> None:
"""
Check the binding constraint's matrix values for the specified time step.
command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT
version: int = 1
Args:
time_step: The frequency of the binding constraint: "hourly", "daily" or "weekly".
values: The binding constraint's 2nd member matrix.
# Properties of the `CREATE_BINDING_CONSTRAINT` command:
name: str
Raises:
ValueError:
If the matrix shape does not match the expected shape for the given time step.
If the matrix values contain NaN (Not-a-Number).
"""
shapes = {
BindingConstraintFrequency.HOURLY: (8760, 3),
BindingConstraintFrequency.DAILY: (365, 3),
BindingConstraintFrequency.WEEKLY: (52, 3),
}
# Check the matrix values and create the corresponding matrix link
array = np.array(values, dtype=np.float64)
if array.shape != shapes[time_step]:
raise ValueError(f"Invalid matrix shape {array.shape}, expected {shapes[time_step]}")
if np.isnan(array).any():
raise ValueError("Matrix values cannot contain NaN")


class AbstractBindingConstraintCommand(ICommand, metaclass=ABCMeta):
"""
Abstract class for binding constraint commands.
"""

# todo: add the `name` attribute because it should also be updated
enabled: bool = True
time_step: BindingConstraintFrequency
operator: BindingConstraintOperator
Expand All @@ -43,6 +68,42 @@ class CreateBindingConstraint(ICommand):
filter_synthesis: Optional[str] = None
comments: Optional[str] = None

def to_dto(self) -> CommandDTO:
args = {
"enabled": self.enabled,
"time_step": self.time_step.value,
"operator": self.operator.value,
"coeffs": self.coeffs,
"comments": self.comments,
"filter_year_by_year": self.filter_year_by_year,
"filter_synthesis": self.filter_synthesis,
}
if self.values is not None:
args["values"] = strip_matrix_protocol(self.values)
return CommandDTO(
action=self.command_name.value,
args=args,
)

def get_inner_matrices(self) -> List[str]:
if self.values is not None:
if not isinstance(self.values, str): # pragma: no cover
raise TypeError(repr(self.values))
return [strip_matrix_protocol(self.values)]
return []


class CreateBindingConstraint(AbstractBindingConstraintCommand):
"""
Command used to create a binding constraint.
"""

command_name: CommandName = CommandName.CREATE_BINDING_CONSTRAINT
version: int = 1

# Properties of the `CREATE_BINDING_CONSTRAINT` command:
name: str

@validator("values", always=True)
def validate_series(
cls,
Expand All @@ -65,18 +126,7 @@ def validate_series(
# Check the matrix link
return validate_matrix(v, values)
if isinstance(v, list):
shapes = {
BindingConstraintFrequency.HOURLY: (8760, 3),
BindingConstraintFrequency.DAILY: (365, 3),
BindingConstraintFrequency.WEEKLY: (52, 3),
}
# Check the matrix values and create the corresponding matrix link
array = np.array(v, dtype=np.float64)
if array.shape != shapes[time_step]:
raise ValueError(f"Invalid matrix shape {array.shape}, expected {shapes[time_step]}")
if np.isnan(array).any():
raise ValueError("Matrix values cannot contain NaN")
v = cast(MatrixType, array.tolist())
check_matrix_values(time_step, v)
return validate_matrix(v, values)
# Invalid datatype
# pragma: no cover
Expand Down Expand Up @@ -108,20 +158,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
)

def to_dto(self) -> CommandDTO:
return CommandDTO(
action=CommandName.CREATE_BINDING_CONSTRAINT.value,
args={
"name": self.name,
"enabled": self.enabled,
"time_step": self.time_step.value,
"operator": self.operator.value,
"coeffs": self.coeffs,
"values": strip_matrix_protocol(self.values),
"comments": self.comments,
"filter_year_by_year": self.filter_year_by_year,
"filter_synthesis": self.filter_synthesis,
},
)
dto = super().to_dto()
dto.args["name"] = self.name # type: ignore
return dto

def match_signature(self) -> str:
return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.name)
Expand Down Expand Up @@ -161,10 +200,3 @@ def _create_diff(self, other: "ICommand") -> List["ICommand"]:
command_context=other.command_context,
)
]

def get_inner_matrices(self) -> List[str]:
if self.values is not None:
if not isinstance(self.values, str): # pragma: no cover
raise TypeError(repr(self.values))
return [strip_matrix_protocol(self.values)]
return []
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic import Field, validator
from pydantic import validator

from antarest.core.model import JSON
from antarest.core.utils.utils import assert_this
from antarest.matrixstore.model import MatrixData
from antarest.study.storage.rawstudy.model.filesystem.config.binding_constraint import BindingConstraintFrequency
from antarest.study.storage.rawstudy.model.filesystem.config.model import FileStudyTreeConfig
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
from antarest.study.storage.variantstudy.business.utils import strip_matrix_protocol, validate_matrix
from antarest.study.storage.variantstudy.business.utils import validate_matrix
from antarest.study.storage.variantstudy.business.utils_binding_constraint import apply_binding_constraint
from antarest.study.storage.variantstudy.model.command.common import (
BindingConstraintOperator,
CommandName,
CommandOutput,
from antarest.study.storage.variantstudy.model.command.common import CommandName, CommandOutput
from antarest.study.storage.variantstudy.model.command.create_binding_constraint import (
AbstractBindingConstraintCommand,
check_matrix_values,
)
from antarest.study.storage.variantstudy.model.command.icommand import MATCH_SIGNATURE_SEPARATOR, ICommand
from antarest.study.storage.variantstudy.model.model import CommandDTO

__all__ = ("UpdateBindingConstraint",)

MatrixType = List[List[MatrixData]]


class UpdateBindingConstraint(ICommand):
class UpdateBindingConstraint(AbstractBindingConstraintCommand):
"""
Command used to update a binding constraint.
"""
Expand All @@ -32,14 +31,6 @@ class UpdateBindingConstraint(ICommand):

# Properties of the `UPDATE_BINDING_CONSTRAINT` command:
id: str
enabled: bool = True
time_step: BindingConstraintFrequency
operator: BindingConstraintOperator
coeffs: Dict[str, List[float]]
values: Optional[Union[MatrixType, str]] = Field(None, description="2nd member matrix")
filter_year_by_year: Optional[str] = None
filter_synthesis: Optional[str] = None
comments: Optional[str] = None

@validator("values", always=True)
def validate_series(
Expand All @@ -55,18 +46,7 @@ def validate_series(
# Check the matrix link
return validate_matrix(v, values)
if isinstance(v, list):
shapes = {
BindingConstraintFrequency.HOURLY: (8760, 3),
BindingConstraintFrequency.DAILY: (365, 3),
BindingConstraintFrequency.WEEKLY: (52, 3),
}
# Check the matrix values and create the corresponding matrix link
array = np.array(v, dtype=np.float64)
if array.shape != shapes[time_step]:
raise ValueError(f"Invalid matrix shape {array.shape}, expected {shapes[time_step]}")
if np.isnan(array).any():
raise ValueError("Matrix values cannot contain NaN")
v = cast(MatrixType, array.tolist())
check_matrix_values(time_step, v)
return validate_matrix(v, values)
# Invalid datatype
# pragma: no cover
Expand Down Expand Up @@ -108,22 +88,9 @@ def _apply(self, study_data: FileStudy) -> CommandOutput:
)

def to_dto(self) -> CommandDTO:
args = {
"id": self.id,
"enabled": self.enabled,
"time_step": self.time_step.value,
"operator": self.operator.value,
"coeffs": self.coeffs,
"comments": self.comments,
"filter_year_by_year": self.filter_year_by_year,
"filter_synthesis": self.filter_synthesis,
}
if self.values is not None:
args["values"] = strip_matrix_protocol(self.values)
return CommandDTO(
action=CommandName.UPDATE_BINDING_CONSTRAINT.value,
args=args,
)
dto = super().to_dto()
dto.args["id"] = self.id # type: ignore
return dto

def match_signature(self) -> str:
return str(self.command_name.value + MATCH_SIGNATURE_SEPARATOR + self.id)
Expand All @@ -146,9 +113,3 @@ def match(self, other: ICommand, equal: bool = False) -> bool:

def _create_diff(self, other: "ICommand") -> List["ICommand"]:
return [other]

def get_inner_matrices(self) -> List[str]:
if self.values is not None:
assert_this(isinstance(self.values, str))
return [strip_matrix_protocol(self.values)]
return []

0 comments on commit b41d957

Please sign in to comment.