Skip to content

Commit

Permalink
D. Jabs:
Browse files Browse the repository at this point in the history
- Changed check_default, where shape=None and default=None is given an error now
  • Loading branch information
Dennis Jabs committed Nov 15, 2023
1 parent 3113b01 commit c5609d0
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
2 changes: 1 addition & 1 deletion PyHyperparameterSpace/hp/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -
return self._choices[indices]
return np.array([self._choices[idx] for idx in indices])
else:
raise Exception(f"Unknown Probability Distribution {self._distribution}!")
raise Exception(f"Unknown Distribution {self._distribution}!")

def valid_configuration(self, value: Any) -> bool:
if isinstance(value, (list, np.ndarray)):
Expand Down
34 changes: 15 additions & 19 deletions PyHyperparameterSpace/hp/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,12 @@ def _is_legal_bounds(self, bounds: Union[tuple[float, float], tuple[int, int]]):
return False

def _check_default(self, default: Union[int, float, np.ndarray]) -> Union[int, float, np.ndarray]:
if default is None:
# Case: default is not given
if self._shape is None or self._shape == 1 or self._shape == (1,):
# Case: default should be single dimensional
return (self.lb + self.ub) / 2
else:
# Case: default should be multidimensional
return np.full(shape=self._shape, fill_value=((self.lb + self.ub) / 2))
if default is None and self._shape == (1,):
# Case: default is not given and shape refers to single dimensional values
return (self.lb + self.ub) / 2
elif default is None and self._shape is not None and all(isinstance(s, (int, np.int_)) for s in self._shape):
# Case: default is not given and shape refers to multi dimensional values
return np.full(shape=self._shape, fill_value=((self.lb + self.ub) / 2))
elif self._is_legal_default(default):
return default
else:
Expand Down Expand Up @@ -342,7 +340,7 @@ def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -
sample = random.uniform(low=self.lb, high=self.ub, size=sample_size)
return sample
else:
raise Exception("Unknown Distribution!")
raise Exception(f"Unknown Distribution {self._distribution}!")

def valid_configuration(self, value: Any) -> bool:
if isinstance(value, (list, np.ndarray)):
Expand Down Expand Up @@ -412,20 +410,18 @@ def _is_legal_bounds(self, bounds: tuple[int, int]) -> bool:
return False

def _check_default(self, default: Union[int, np.ndarray]) -> Union[int, np.ndarray]:
if default is None:
# Case: default is not given
if self._shape is None or self._shape == 1 or self._shape == (1,):
# Case: shape refers to single dimensional
return int((self.lb + self.ub) / 2)
else:
# Case: Shape refers to multidimensional
return np.full(shape=self._shape, fill_value=int((self.lb + self.ub) / 2))
if default is None and self._shape == (1,):
# Case: default is not given and shape refers to single dimensional values
return int((self.lb + self.ub) / 2)
elif default is None and self._shape is not None and all(isinstance(s, (int, np.int_)) for s in self._shape):
# Case: default is not given and shape refers to multidimensional values
return np.full(shape=self._shape, fill_value=int((self.lb + self.ub) / 2))
elif self._is_legal_default(default):
# Case: default value is legal
return default
else:
# Case: default value is illegal
raise Exception(f"Illegal default value {default}!")
raise Exception(f"Illegal default {default}. The argument should be in between the bounds (lower, upper)!")

def _is_legal_default(self, default: Union[int, np.ndarray]) -> bool:
if not isinstance(default, int) and \
Expand Down Expand Up @@ -479,7 +475,7 @@ def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -
sample = random.randint(low=self.lb, high=self.ub, size=sample_size)
return sample
else:
raise Exception("Unknown Distribution!")
raise Exception(f"Unknown Distribution {self._distribution}!")

def valid_configuration(self, value: Any) -> bool:
if isinstance(value, (list, np.ndarray)):
Expand Down

0 comments on commit c5609d0

Please sign in to comment.