Skip to content

Commit

Permalink
D. Jabs:
Browse files Browse the repository at this point in the history
- Updated the Categorical hyperparameter: the weights are now stored inside the Choice() distribution instead of inside the Categorical class itself
- Updated unittests
  • Loading branch information
Dennis Jabs committed Nov 3, 2023
1 parent c21f82c commit ea2dac9
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 199 deletions.
51 changes: 45 additions & 6 deletions PyHyperparameterSpace/dist/categorical.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,59 @@
from typing import Union
import numpy as np


from PyHyperparameterSpace.dist.abstract_dist import Distribution


class Choice(Distribution):
"""
Class for representing a Categorical choice dist.
TODO: Refactor
"""

def __init__(self):
pass
def __init__(self, weights: Union[list[float], list[int], np.ndarray]):
self.weights = None
self.change_distribution(weights)

def change_distribution(self, weights: Union[list[float], list[int], np.ndarray]):
    """
    Replaces the weights of the distribution.

    The given weights are validated and afterwards normalized with self._normalize() before they are stored
    in self.weights.

    Args:
        weights (Union[list[float], list[int], np.ndarray]):
            Non-negative weights, given as a 1-dimensional sequence of size (n,)

    Raises:
        AssertionError:
            If weights is not 1-dimensional or contains a negative value
    """
    weights = np.array(weights)

    assert weights.ndim == 1, f"Illegal weights {weights}. Argument should be a matrix of size (n,)!"
    # Compare on the array itself: handing a generator expression to np.all() wraps the generator object,
    # which is always truthy, so negative weights would never be rejected by this check.
    assert np.all(weights >= 0.0), \
        f"Illegal weights {weights}. Each p inside weights should >= 0.0!"

    # Normalize the weights
    self.weights = self._normalize(weights)

@classmethod
def _normalize(cls, p: Union[list[float], np.ndarray]) -> Union[list[float], np.ndarray]:
"""
Normalizes the given probability distribution, so that sum(p)=1.
A list input is returned as a list, every other input is returned as a np.ndarray.
Args:
p (Union[list[float], np.ndarray]):
Non-normalized probability distribution
Returns:
Union[list[float], np.ndarray]:
Normalized probability distribution
"""
# Negative entries are rejected before any normalization happens
assert all(0.0 <= prob for prob in p), \
"The given non-normalized dist p cannot contain negative values!"

# Remember the container kind of the input, so the result is built as the same kind
if isinstance(p, list):
result_type = list
else:
result_type = np.array

# NOTE(review): the next two lines look like the removed pre-commit change_distribution() left over from
# the diff view; the otherwise identical copy of _normalize shown for abstract_hp.py does not contain
# them - confirm against the repository.
def change_distribution(**kwargs):
raise Exception("Illegal call of change_distribution(). Choice distribution cannot be changed!")
# NOTE(review): if every entry of p is 0, sum_p is 0 and the division below cannot produce a valid
# distribution - TODO confirm that callers never pass all-zero weights.
sum_p = np.sum(p)
if sum_p == 1:
# Case: p is already normalized
return result_type(p)
# Case: p should be normalized
return result_type([prob / sum_p for prob in p])

def __str__(self):
return "Choice()"
return f"Choice(weights={self.weights})"

def __repr__(self):
return self.__str__()
28 changes: 0 additions & 28 deletions PyHyperparameterSpace/hp/abstract_hp.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,31 +220,3 @@ def _get_sample_size(
else:
# Case: shape is a tuple
return size, *shape

@classmethod
def _normalize(cls, p: Union[list[float], np.ndarray]) -> Union[list[float], np.ndarray]:
    """
    Normalizes the given probability distribution, so that sum(p)=1.

    A list input is returned as a list, every other input is returned as a np.ndarray.

    Args:
        p (Union[list[float], np.ndarray]):
            Non-normalized probability distribution

    Returns:
        Union[list[float], np.ndarray]:
            Normalized probability distribution

    Raises:
        AssertionError:
            If p contains a negative value
        ValueError:
            If the values of p sum up to 0 (p is empty or all-zero), so that p cannot be normalized
    """
    assert all(0.0 <= prob for prob in p), \
        "The given non-normalized dist p cannot contain negative values!"

    # Build the result as the same container kind as the input
    result_type = list if isinstance(p, list) else np.array

    sum_p = np.sum(p)
    if sum_p == 1:
        # Case: p is already normalized
        return result_type(p)

    if sum_p == 0:
        # Case: p cannot be normalized. Dividing by 0 would silently produce NaN probabilities, which only
        # fail much later when they are used for sampling.
        raise ValueError(f"Illegal dist {p}. The values of p should sum up to a value > 0.0!")

    # Case: p should be normalized
    return result_type([prob / sum_p for prob in p])
92 changes: 17 additions & 75 deletions PyHyperparameterSpace/hp/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ class Categorical(Hyperparameter):
default (Any):
Default value of the hyperparameter
distribution (Distribution):
distribution (Union[Distribution, None]):
Distribution from where we sample new values for hyperparameter
weights (Union[list[int], list[float], None]):
Probability distribution for each possible discrete value
"""

def __init__(
Expand All @@ -33,26 +30,20 @@ def __init__(
choices: list[Any],
default: Union[Any, None] = None,
shape: Union[int, tuple[int, ...], None] = None,
distribution: Distribution = Choice(),
weights: Union[list[int], list[float], None] = None,
distribution: Union[Distribution, None] = None,
):
if isinstance(choices, list):
choices = np.array(choices)

if isinstance(weights, list):
weights = np.array(weights)

# First set the variables
self._choices = choices
self._distribution = distribution
self._weights = weights

super().__init__(name=name, shape=shape, default=default)

# Then check the variables and set them again
self._choices = self._check_choices(choices)
self._distribution = self._check_distribution(distribution)
self._weights = self._check_weights(weights)

def get_choices(self) -> list[str]:
"""
Expand All @@ -70,14 +61,6 @@ def get_distribution(self) -> Distribution:
"""
return self._distribution

def get_weights(self) -> list[float]:
"""
Returns:
list[float]:
List of weights for each choice
"""
return self._weights

def _check_choices(self, choices: list[Any]) -> list[Any]:
"""
Checks if the given choices are legal. A choice is called legal, if it fulfills the format [item1, item2, ...]
Expand Down Expand Up @@ -114,9 +97,9 @@ def _is_legal_choices(self, choices: Union[list[Any], None]) -> bool:

def _check_default(self, default: Union[Any, None]) -> Any:
if default is None:
if self._weights is not None:
if self._distribution is not None:
# Case: Take the option with the highest probability as default value
return self._choices[np.argmax(self._weights)]
return self._choices[np.argmax(self._distribution.weights)]
else:
# Case: Take the first option as default value
return self._choices[0]
Expand Down Expand Up @@ -162,20 +145,24 @@ def _is_legal_shape(self, shape: Union[int, tuple[int, ...]]) -> bool:
return True
return False

def _check_distribution(self, distribution: Distribution) -> Distribution:
def _check_distribution(self, distribution: Union[Distribution, None]) -> Distribution:
"""
Checks if the distribution is legal. A distribution is called legal, if the class of the distribution can be
used for the given hyperparameter class.
Args:
distribution (Distribution):
distribution (Union[Distribution, None]):
Distribution to check
Returns:
Distribution:
Legal distribution
"""
if self._is_legal_distribution(distribution):
if distribution is None:
# Case: Distribution is not given
return Choice(weights=np.ones(len(self._choices)))
elif self._is_legal_distribution(distribution):
# Case: Distribution is given and legal
return distribution
else:
raise Exception(f"Illegal distribution {distribution}!")
Expand All @@ -186,58 +173,14 @@ def _is_legal_distribution(self, distribution: Distribution) -> bool:
Args:
distribution (Distribution):
distribution to check
Distribution to check
Returns:
bool:
True if the given distribution can be used for the hyperparameter class
"""
if isinstance(distribution, Choice):
return True
return False

def _check_weights(self, weights: Union[list[int], list[float], np.ndarray, None]) -> np.ndarray:
    """
    Validates the given weights and turns them into a probability distribution.

    If no weights are given (None), every choice gets the same weight. Given weights are only accepted if
    self._is_legal_weights() approves them. In both cases the result is normalized with
    Categorical._normalize().

    Args:
        weights (Union[list[int], list[float], np.ndarray, None]):
            Weights to check

    Returns:
        np.ndarray:
            Normalized weights

    Raises:
        Exception:
            If the given weights are rejected by self._is_legal_weights()
    """
    if weights is None:
        # Case: No weights are given, so use one equal weight per choice
        uniform = np.array([1] * len(self._choices))
        return Categorical._normalize(uniform)

    if not self._is_legal_weights(weights):
        # Case: Weights are given but illegal
        raise Exception(f"Illegal weights {weights}!")

    # Case: Weights are given and legal
    return Categorical._normalize(np.array(weights))

def _is_legal_weights(self, weights: Union[list[int], list[float], np.ndarray]) -> bool:
"""
Returns True if the given weights (...)
- fulfills the right format [w1, w2, ...]
- length of weights and choices are equal
- for all w_i >= 0
Args:
weights (Union[list[int], list[float], np.ndarray]):
Weights to check
Returns:
bool:
True if weights are legal
"""
if isinstance(weights, (list, np.ndarray)) and len(weights) == len(self._choices) and \
all(0 <= w for w in weights):
return True
return np.isclose(sum(distribution.weights), 1) and distribution.weights.shape == (self._choices.shape[0],)
return False

def change_distribution(self, **kwargs):
Expand All @@ -248,12 +191,13 @@ def change_distribution(self, **kwargs):
**kwargs (dict):
Parameters that defines the distribution
"""
self._weights = self._check_weights(weights=kwargs["weights"])
self._distribution.change_distribution(**kwargs)
self._check_distribution(self._distribution)

def sample(self, random: np.random.RandomState, size: Union[int, None] = None) -> Any:
if isinstance(self._distribution, Choice):
# Case: Sample from given distribution (with weights)
indices = random.choice(len(self._choices), size=size, replace=True, p=self._weights)
indices = random.choice(len(self._choices), size=size, replace=True, p=self._distribution.weights)
if isinstance(indices, int):
# Case: Only a single sample should be returned
if len(self._shape) > 1:
Expand Down Expand Up @@ -283,18 +227,16 @@ def __hash__(self) -> int:
return hash(self.__repr__())

def __repr__(self) -> str:
text = f"Categorical({self._name}, choices={self._choices}, default={self._default}, weights={self._weights})"
text = f"Categorical({self._name}, choices={self._choices}, default={self._default}, distribution={self._distribution})"
return text

def __getstate__(self) -> dict:
state = super().__getstate__()
state["choices"] = self._choices
state["distribution"] = self._distribution
state["weights"] = self._weights
return state

def __setstate__(self, state) -> dict:
super().__setstate__(state)
self._choices = state["choices"]
self._distribution = state["distribution"]
self._weights = state["weights"]
50 changes: 0 additions & 50 deletions tests/hp/test_abstract_hp.py

This file was deleted.

Loading

0 comments on commit ea2dac9

Please sign in to comment.