Skip to content

Commit

Permalink
D. Jabs:
Browse files Browse the repository at this point in the history
- Updated type hints and check functions for continuous.py
  • Loading branch information
Dennis Jabs committed Nov 16, 2023
1 parent 35be973 commit a1b8025
Showing 1 changed file with 59 additions and 73 deletions.
132 changes: 59 additions & 73 deletions PyHyperparameterSpace/hp/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ def _is_legal_bounds(self, bounds: Union[tuple[float, float], tuple[int, int]]):
pass

@abstractmethod
def _check_distribution(self, distribution: Distribution) -> Distribution:
def _check_distribution(self, distribution: Union[Distribution, None]) -> Distribution:
"""
Checks if the distribution is legal. A distribution is called legal, if the class of the distribution can be
used for the given hyperparameter class.
Args:
distribution (Distribution):
distribution (Union[Distribution, None]):
Distribution to check
Returns:
Expand Down Expand Up @@ -198,13 +198,10 @@ def __init__(
self,
name: str,
bounds: Union[tuple[float, float], tuple[int, int]],
default: Union[int, float, list, np.ndarray] = None,
default: Any = None,
shape: Union[tuple[int, ...], None] = None,
distribution: Distribution = None,
distribution: Union[Distribution, None] = None,
):
if isinstance(default, list):
default = np.array(default, dtype=float)

super().__init__(name=name, shape=shape, bounds=bounds, default=default, distribution=distribution)

def _check_bounds(self, bounds: Union[tuple[float, float], tuple[int, int]]) \
Expand All @@ -216,31 +213,24 @@ def _check_bounds(self, bounds: Union[tuple[float, float], tuple[int, int]]) \
f"Illegal bounds {bounds}. The argument should have the format (lower, upper), where lower < upper!")

def _is_legal_bounds(self, bounds: Union[tuple[float, float], tuple[int, int]]):
if isinstance(bounds, tuple) and len(bounds) == 2 and \
all(isinstance(b, (float, int, np.int_, np.float_)) for b in bounds) and bounds[0] < bounds[1]:
return True
else:
return False
if isinstance(bounds, tuple) and len(bounds) == 2:
return all(isinstance(b, (float, int, np.int_, np.float_)) for b in bounds) and bounds[0] < bounds[1]
return False

def _check_default(self, default: Union[int, float, np.ndarray]) -> Union[int, float, np.ndarray]:
def _check_default(self, default: Any) -> Any:
if default is None and self._shape == (1,):
# Case: default is not given and shape refers to single dimensional values
return (self.lb + self.ub) / 2
elif default is None and self._shape is not None and all(isinstance(s, (int, np.int_)) for s in self._shape):
# Case: default is not given and shape refers to multi dimensional values
return np.full(shape=self._shape, fill_value=((self.lb + self.ub) / 2))
elif self._is_legal_default(default):
elif default is not None and self._is_legal_default(default):
return default
else:
raise Exception(f"Illegal default {default}. The argument should be in between the bounds (lower, upper)!")

def _is_legal_default(self, default: Any) -> bool:
if not isinstance(default, (float, int, np.int_, np.float_)) and \
not (isinstance(default, np.ndarray) and np.issubdtype(default.dtype, np.floating)) and \
not (isinstance(default, np.ndarray) and np.issubdtype(default.dtype, np.integer)):
# Case: default is not in the right format
return False
if isinstance(default, (float, int, np.int_, np.float_)):
if isinstance(default, (float, int, np.float_, np.int_)):
# Case: default is single dimensional
return self.lb <= default < self.ub
elif isinstance(default, np.ndarray):
Expand All @@ -249,30 +239,29 @@ def _is_legal_default(self, default: Any) -> bool:
return False

def _check_shape(self, shape: Union[tuple[int, ...], None]) -> tuple[int, ...]:
if shape is None and isinstance(self._default, (float, int, np.int_, np.float_)):
if shape is None and isinstance(self._default, (float, int, np.float_, np.int_)):
# Case: shape is not given and default is single dimensional
return 1,
elif shape is None and isinstance(self._default, np.ndarray):
# Case: shape is not given and default is multidimensional
return self._default.shape
elif self._is_legal_shape(shape):
elif shape is not None and self._is_legal_shape(shape):
# Case: shape is given and legal
return shape
else:
# Case: shape is illegal
raise Exception(f"Illegal shape {shape}. The argument should have the right format (dim1, ...)!")

def _is_legal_shape(self, shape: tuple[int, ...]) -> bool:
if shape == (1,) and isinstance(self._default, (float, int, np.int_, np.float_)):
# Case: shape and default refers to single dimensional
return True
elif isinstance(shape, tuple) and all(isinstance(s, int) for s in shape) and \
isinstance(self._default, np.ndarray) and shape == self._default.shape:
# Case: shape and default refers to multidimensional
return True
if shape == (1,):
# Case: shape refers to single dimensional
return isinstance(self._default, (float, int, np.int_, np.float_))
elif isinstance(shape, tuple) and all(isinstance(s, int) for s in shape):
# Case: shape refers to multidimensional
return isinstance(self._default, np.ndarray) and shape == self._default.shape
return False

def _check_distribution(self, distribution: Union[Distribution, None]) -> Union[Distribution, None]:
def _check_distribution(self, distribution: Union[Distribution, None]) -> Distribution:
if distribution is None:
return Uniform(lb=self.lb, ub=self.ub)
elif self._is_legal_distribution(distribution):
Expand Down Expand Up @@ -350,33 +339,29 @@ def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -
raise Exception(f"Unknown Distribution {self._distribution}!")

def valid_configuration(self, value: Any) -> bool:
if isinstance(value, (list, np.ndarray)):
if isinstance(value, np.ndarray):
# Case: Value is multidimensional
value = np.array(value)
return np.all((self.lb <= value) & (value < self.ub)) and self._shape == value.shape
elif isinstance(value, (int, float)):
elif isinstance(value, (float, int, np.float_, np.int_)):
# Case: value is single dimensional
return self.lb <= value < self.ub
return False

def adjust_configuration(self, value: Any) -> Any:
if isinstance(value, (list, np.ndarray)) and \
if isinstance(value, np.ndarray) and \
(np.issubdtype(value.dtype, np.float_) or np.issubdtype(value.dtype, np.int_)):
# Case: value is multidimensional
value = np.array(value)

# Do not exceed lower, upper bound
value[value < self.lb] = self.lb
value[value >= self.ub] = self.ub - 1e-10

return value
elif isinstance(value, (int, float, np.int_, np.float_)):
elif isinstance(value, (float, int, np.float_, np.int_)):
# Case: value is single dimensional
# Do not exceed lower, upper bound
if value < self.lb:
value = self.lb
elif value >= self.ub:
value = self.ub - 1e-10

return value
else:
# Case: value is illegal
Expand All @@ -392,8 +377,13 @@ def __hash__(self) -> int:
return hash(self.__repr__())

def __repr__(self) -> str:
text = f"Float({self._name}, bounds={self._bounds}, default={self._default}, shape={self._shape}, distribution={self._distribution})"
return text
header = f"Float({self._name}, "
bounds = f"bounds={self._bounds}, "
default = f"default={self._default}, "
shape = f"shape={self._shape}, "
distribution = f"distribution={self._distribution}"
end = ")"
return "".join([header, bounds, default, shape, distribution, end])


class Integer(Continuous):
Expand All @@ -410,7 +400,7 @@ class Integer(Continuous):
default (Any):
Default value of the hyperparameter
shape (Union[int, tuple[int, ...], None]):
shape (Union[tuple[int, ...], None]):
Shape of the hyperparameter
distribution (Union[Distribution, None]):
Expand All @@ -421,50 +411,45 @@ def __init__(
self,
name: str,
bounds: Union[tuple[int, int]],
default: Union[int, list, np.ndarray, None] = None,
default: Any = None,
shape: Union[tuple[int, ...], None] = None,
distribution: Distribution = None,
distribution: Union[Distribution, None] = None,
):
super().__init__(name=name, shape=shape, bounds=bounds, default=default, distribution=distribution)

def _check_bounds(self, bounds: tuple[int, int]) -> tuple[int, int]:
if self._is_legal_bounds(bounds):
return bounds
else:
raise Exception(f"Illegal bounds {bounds}!")
raise Exception(f"Illegal bounds {bounds}. The argument should have the right format (lower, upper), where lower < upper!")

def _is_legal_bounds(self, bounds: tuple[int, int]) -> bool:
if isinstance(bounds, tuple) and len(bounds) == 2 and \
all(isinstance(b, (int, np.int_)) for b in bounds) and bounds[0] < bounds[1]:
return True
else:
return False
if isinstance(bounds, tuple) and len(bounds) == 2:
return all(isinstance(b, (int, np.int_)) for b in bounds) and bounds[0] < bounds[1]
return False

def _check_default(self, default: Union[int, np.ndarray]) -> Union[int, np.ndarray]:
def _check_default(self, default: Any) -> Any:
if default is None and self._shape == (1,):
# Case: default is not given and shape refers to single dimensional values
return int((self.lb + self.ub) / 2)
elif default is None and self._shape is not None and all(isinstance(s, (int, np.int_)) for s in self._shape):
# Case: default is not given and shape refers to multidimensional values
return np.full(shape=self._shape, fill_value=int((self.lb + self.ub) / 2))
elif self._is_legal_default(default):
elif default is not None and self._is_legal_default(default):
# Case: default value is legal
return default
else:
# Case: default value is illegal
raise Exception(f"Illegal default {default}. The argument should be in between the bounds (lower, upper)!")

def _is_legal_default(self, default: Union[int, np.ndarray]) -> bool:
if not isinstance(default, int) and \
not (isinstance(default, np.ndarray) and np.issubdtype(default.dtype, np.integer)):
# Case: default is not in the right format!
return False
def _is_legal_default(self, default: Any) -> bool:
if isinstance(default, (int, np.int_)):
# Case: default is single dimensional
return self.lb <= default <= self.ub
else:
# Case: default is multidimensional
return np.all((default >= self.lb) & (default <= self.ub))
return False

def _check_shape(self, shape: Union[tuple[int, ...], None]) -> tuple[int, ...]:
if shape is None and isinstance(self._default, (int, np.int_)):
Expand All @@ -473,20 +458,20 @@ def _check_shape(self, shape: Union[tuple[int, ...], None]) -> tuple[int, ...]:
elif shape is None and isinstance(self._default, np.ndarray):
# Case: shape is not given and default is multidimensional
return self._default.shape
elif self._is_legal_shape(shape):
elif shape is not None and self._is_legal_shape(shape):
# Case: shape is given
return shape
else:
raise Exception(
f"Illegal shape {shape}. The argument should be in the format (lower, upper), where lower < upper!")
f"Illegal shape {shape}. The argument should be in the format (dim1, ...)!")

def _is_legal_shape(self, shape: tuple[int, ...]) -> bool:
if shape == (1,) and isinstance(self._default, (int, np.int_)):
# Case: shape and default refers to single dimensional
return True
elif isinstance(shape, tuple) and all(isinstance(s, (int, np.int_)) for s in shape) and \
isinstance(self._default, np.ndarray) and shape == self._default.shape:
return True
if shape == (1,):
# Case: shape refers to single dimensional
return isinstance(self._default, (int, np.int_))
elif isinstance(shape, tuple) and all(isinstance(s, (int, np.int_)) for s in shape):
# Case: shape refers to multidimensional
return isinstance(self._default, np.ndarray) and shape == self._default.shape
return False

def _check_distribution(self, distribution: Union[Distribution, None]) -> Distribution:
Expand Down Expand Up @@ -526,9 +511,8 @@ def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -
raise Exception(f"Unknown Distribution {self._distribution}!")

def valid_configuration(self, value: Any) -> bool:
if isinstance(value, (list, np.ndarray)):
if isinstance(value, np.ndarray):
# Case: Value is multi-dimensional
value = np.array(value)
return np.all((self.lb <= value) & (value < self.ub)) and self._shape == value.shape
elif isinstance(value, (int, np.int_)):
# Case: value is single-dimensional
Expand All @@ -537,17 +521,15 @@ def valid_configuration(self, value: Any) -> bool:
return False

def adjust_configuration(self, value: Any) -> Any:
if isinstance(value, (list, np.ndarray)) and np.issubdtype(value.dtype, np.int_):
if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.int_):
# Case: value is multidimensional
value = np.array(value)

# Do not exceed lower, upper bound
value[value < self.lb] = self.lb
value[value >= self.ub] = self.ub - 1

return value
elif isinstance(value, (int, np.int_)):
# Case: value is single dimensional
# Do not exceed lower, upper bound
if value < self.lb:
value = self.lb
elif value >= self.ub:
Expand All @@ -567,5 +549,9 @@ def __hash__(self) -> int:
return hash(self.__repr__())

def __repr__(self) -> str:
text = f"Integer({self._name}, bounds={self._bounds}, default={self._default}, shape={self._shape})"
return text
header = f"Integer({self._name}, "
bounds = f"bounds={self._bounds}, "
default = f"default={self._default}, "
shape = f"shape={self._shape}"
end = ")"
return "".join([header, bounds, default, shape, end])

0 comments on commit a1b8025

Please sign in to comment.