diff --git a/iris/algorithms/algorithm.py b/iris/algorithms/algorithm.py
new file mode 100644
index 0000000..ab985b4
--- /dev/null
+++ b/iris/algorithms/algorithm.py
@@ -0,0 +1,99 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Algorithm class for distributed blackbox optimization library."""
+
+import abc
+import pathlib
+from typing import Any, Dict, Sequence, Union
+from iris import worker_util
+import numpy as np
+
+
+PARAMS_TO_EVAL = "params_to_eval"
+OBS_NORM_BUFFER_STATE = "obs_norm_buffer_state"
+UPDATE_OBS_NORM_BUFFER = "update_obs_norm_buffer"
+
+
+class BlackboxAlgorithm(abc.ABC):
+  """Base class for Blackbox optimization algorithms."""
+
+  def __init__(self,
+               num_suggestions: int,
+               random_seed: int,
+               num_evals: int = 50) -> None:
+    """Initializes the blackbox algorithm.
+
+    Args:
+      num_suggestions: Number of suggestions to sample for blackbox function
+        evaluation.
+      random_seed: Seed for numpy random state.
+      num_evals: Number of times to evaluate blackbox function while reporting
+        performance of current parameters.
+    """
+    self._num_suggestions = num_suggestions
+    self._num_evals = num_evals
+    self._np_random_state = np.random.RandomState(random_seed)
+    self._opt_params = np.empty(0)
+
+  @property
+  def opt_params(self):
+    """Returns the optimizer parameters."""
+    return self._opt_params
+
+  @abc.abstractmethod
+  def initialize(self, state: Dict[str, Any]) -> None:
+    """Initializes the algorithm from initial worker state."""
+    raise NotImplementedError(
+        "Should be implemented in derived classes for specific algorithms.")
+
+  @abc.abstractmethod
+  def get_param_suggestions(self,
+                            evaluate: bool = False) -> Sequence[Dict[str, Any]]:
+    """Suggests a list of inputs to evaluate the Blackbox function on."""
+    raise NotImplementedError(
+        "Should be implemented in derived classes for specific algorithms.")
+
+  @abc.abstractmethod
+  def process_evaluations(self,
+                          eval_results: Sequence[worker_util.EvaluationResult]):
+    """Processes the list of Blackbox function evaluations returned from workers.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent. The value is a tuple of the
+        suggestion evaluated and the result after evaluation.
+ """ + del eval_results + raise NotImplementedError( + "Should be implemented in derived classes for specific algorithms.") + + @property + def state(self): + return {PARAMS_TO_EVAL: self._opt_params} + + @state.setter + def state(self, new_state: Dict[str, Any]) -> None: + self._opt_params = new_state[PARAMS_TO_EVAL] + + def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: + self.state = new_state[PARAMS_TO_EVAL] + + def maybe_save_custom_checkpoint(self, + state: Dict[str, Any], + checkpoint_path: Union[pathlib.Path, str] + ) -> None: + """If implemented, saves a custom checkpoint to checkpoint_path.""" + del state, checkpoint_path + return None diff --git a/iris/algorithms/ars_algorithm.py b/iris/algorithms/ars_algorithm.py new file mode 100644 index 0000000..62df636 --- /dev/null +++ b/iris/algorithms/ars_algorithm.py @@ -0,0 +1,642 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Algorithm class for Augmented Random Search Blackbox algorithm.""" + +import collections +import datetime +import math +import pathlib +import pickle as pkl +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from absl import logging +from flax import linen as nn +from iris import checkpoint_util +from iris import normalizer +from iris import worker_util +from iris.algorithms import algorithm +from iris.algorithms import stateless_perturbation_generators +import jax +import jax.numpy as jnp +import numpy as np + +PRNGKey = jax.Array + +_DUMMY_REWARD = -1_000_000_000.0 + + +class MLP(nn.Module): + """Defines an MLP model for learned hyper-params.""" + + hidden_sizes: Sequence[int] = (32, 16) + output_size: int = 2 + + @nn.compact + def __call__(self, x: jnp.ndarray, state: Any): + for feat in self.hidden_sizes: + x = nn.Dense(feat)(x) + x = nn.tanh(x) + x = nn.Dense(self.output_size)(x) + return nn.sigmoid(x), state + + def initialize_carry(self, rng: PRNGKey, params: jnp.ndarray) -> Any: + del rng, params + return None + + +class AugmentedRandomSearch(algorithm.BlackboxAlgorithm): + """Augmented random search algorithm for blackbox optimization.""" + + def __init__(self, + std: float | Callable[[int], float], + step_size: float | Callable[[int], float], + top_percentage: float = 1.0, + orthogonal_suggestions: bool = False, + quasirandom_suggestions: bool = False, + top_sort_type: str = "max", + obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None, + **kwargs) -> None: + """Initializes the augmented random search algorithm. + + Args: + std: Standard deviation for normal perturbations around current + optimization parameter vector. A std schedule as a function of iteration + number can also be given. + step_size: Step size for gradient ascent. A step size schedule as a + function of iteration number can also be given. + top_percentage: Fraction of top performing perturbations to use for + gradient estimation. + orthogonal_suggestions: Whether to orthogonalize the perturbations. 
+      quasirandom_suggestions: Whether quasirandom perturbations should be
+        used; valid only if orthogonal_suggestions = True.
+      top_sort_type: How to sort evaluation results for selecting top
+        directions. Valid options are: "max" and "diff".
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for
+        online mean std observation normalizer.
+      **kwargs: Other keyword arguments for base class.
+    """
+    super().__init__(**kwargs)
+    self._iteration = 0
+    self._std = std
+    self._step_size = step_size
+    self._last_std_used = 1.0
+    self._num_top = int(top_percentage * self._num_suggestions)
+    self._num_top = max(1, self._num_top)
+    self._orthogonal_suggestions = orthogonal_suggestions
+    self._quasirandom_suggestions = quasirandom_suggestions
+    self._top_sort_type = top_sort_type
+    self._obs_norm_data_buffer = obs_norm_data_buffer
+
+  def initialize(self, state: Dict[str, Any]) -> None:
+    """Initializes the algorithm from initial worker state."""
+    self._opt_params = state["init_params"]
+
+    # Initialize observation normalization buffer with init data from the
+    # worker.
+    if self._obs_norm_data_buffer is not None:
+      self._obs_norm_data_buffer.data = state["obs_norm_buffer_data"]
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]) -> None:
+    """Processes the list of Blackbox function evaluations returned from workers.
+
+    The gradient is computed as a weighted sum of the perturbation directions,
+    weighted by the difference of their values from the current value. The
+    current parameter vector is then updated in the gradient direction with
+    the specified step size.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent. ARS performs antithetic
+        gradient estimation. The suggestions are sent for evaluation in pairs.
+        The eval_results list should contain an even number of entries with
+        the first half entries corresponding to evaluation results of positive
+        perturbations and the last half corresponding to negative
+        perturbations.
+    """
+
+    # Retrieve delta directions from the param suggestions sent for evaluation.
+    pos_eval_results = eval_results[:self._num_suggestions]
+    neg_eval_results = eval_results[self._num_suggestions:]
+    filtered_pos_eval_results = []
+    filtered_neg_eval_results = []
+    for (peval, neval) in zip(pos_eval_results, neg_eval_results):
+      if (peval.params_evaluated.size) and (
+          neval.params_evaluated.size):
+        filtered_pos_eval_results.append(peval)
+        filtered_neg_eval_results.append(neval)
+    params = np.array([r.params_evaluated for r in filtered_pos_eval_results])
+    directions = (params - self._opt_params) / self._last_std_used
+
+    eval_results = filtered_pos_eval_results + filtered_neg_eval_results
+
+    # Get top evaluation results
+    pos_evals = np.array([r.value for r in filtered_pos_eval_results])
+    neg_evals = np.array([r.value for r in filtered_neg_eval_results])
+    if self._top_sort_type == "max":
+      max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0)
+    else:
+      max_evals = np.abs(pos_evals - neg_evals)
+    idx = (-max_evals).argsort()[:self._num_top]
+    pos_evals = pos_evals[idx]
+    neg_evals = neg_evals[idx]
+    all_top_evals = np.hstack([pos_evals, neg_evals])
+    evals = pos_evals - neg_evals
+
+    # Get delta directions corresponding to top evals
+    directions = directions[idx, :]
+
+    # Estimate gradients
+    gradient = np.dot(evals, directions) / evals.shape[0]
+    if not np.isclose(np.std(all_top_evals), 0.0):
+      gradient /= np.std(all_top_evals)
+
+    # Apply gradients
+    step_size = self._step_size
+    if callable(self._step_size):
+      step_size = self._step_size(self._iteration)
+    self._iteration += 1
+    self._opt_params += step_size * gradient
+
+    if self._obs_norm_data_buffer is not None:
+      for r in eval_results:
+        self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data)
+
+  def get_param_suggestions(self,
+                            evaluate: bool = False) -> Sequence[Dict[str, Any]]:
+    """Suggests a list of inputs to evaluate the Blackbox function on.
+
+    Suggestions are sampled from a Gaussian distribution around the current
+    parameter vector. For each suggestion, a dict containing keyword arguments
+    for the worker is sent.
+
+    Args:
+      evaluate: Whether to evaluate current optimization variables for
+        reporting training progress.
+
+    Returns:
+      A list of suggested inputs for the workers to evaluate.
+    """
+    if evaluate:
+      param_suggestions = [self._opt_params] * self._num_evals
+    else:
+      dimensions = self._opt_params.shape[0]
+      if self._orthogonal_suggestions:
+        if self._quasirandom_suggestions:
+          param_suggestions = (
+              stateless_perturbation_generators.RandomHadamardMatrixGenerator(
+                  self._num_suggestions, dimensions
+              ).generate_matrix()
+          )
+        else:
+          # We generate random iid perturbations and orthogonalize them. In
+          # the case when the number of suggestions to be generated is greater
+          # than param dimensionality, we generate multiple orthogonal
+          # perturbation blocks. Rows are orthogonal within a block but not
+          # across blocks.
+          ortho_pert_blocks = []
+          for _ in range(math.ceil(float(self._num_suggestions / dimensions))):
+            perturbations = self._np_random_state.normal(
+                0, 1, (self._num_suggestions, dimensions))
+            ortho_matrix, _ = np.linalg.qr(perturbations.T)
+            ortho_pert_blocks.append(np.sqrt(dimensions) * ortho_matrix.T)
+          param_suggestions = np.vstack(ortho_pert_blocks)
+          param_suggestions = param_suggestions[:self._num_suggestions, :]
+      else:
+        param_suggestions = self._np_random_state.normal(
+            0, 1, (self._num_suggestions, dimensions))
+      self._last_std_used = self._std
+      if callable(self._std):
+        self._last_std_used = self._std(self._iteration)
+      param_suggestions = np.vstack([
+          self._opt_params + self._last_std_used * param_suggestions,
+          self._opt_params - self._last_std_used * param_suggestions
+      ])
+
+    suggestions = []
+    for params in param_suggestions:
+      suggestion = {"params_to_eval": params}
+      if self._obs_norm_data_buffer is not None:
+        suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state
+        suggestion["update_obs_norm_buffer"] = not evaluate
+      suggestions.append(suggestion)
+    return suggestions
+
+  @property
+  def state(self) -> Dict[str, Any]:
+    return self._get_state()
+
+  def _get_state(self) -> Dict[str, Any]:
+    state = {"params_to_eval": self._opt_params}
+    if self._obs_norm_data_buffer is not None:
+      state["obs_norm_state"] = self._obs_norm_data_buffer.state
+    return state
+
+  @state.setter
+  def state(self, new_state: Dict[str, Any]) -> None:
+    self._set_state(new_state)
+
+  def _set_state(self, new_state: Dict[str, Any]) -> None:
+    self._opt_params = new_state["params_to_eval"]
+    if self._obs_norm_data_buffer is not None:
+      self._obs_norm_data_buffer.state = new_state["obs_norm_state"]
+
+  def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None:
+    self.state = new_state
+
+
+class LearnableAugmentedRandomSearch(AugmentedRandomSearch):
+  """Learnable augmented random search algorithm for blackbox optimization."""
+
+  def __init__(
+      self,
+      model: Callable[[], nn.Module] = MLP,
+      model_path: Optional[str] = None,
+      top_percentage: float = 1.0,
+      orthogonal_suggestions: bool = False,
+      quasirandom_suggestions: bool = False,
+      top_sort_type: str = "max",
+      obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None,
+      seed: int = 42,
+      reward_buffer_size: int = 10,
+      **kwargs,
+  ) -> None:
+    """Initializes the learnable augmented random search algorithm.
+
+    Args:
+      model: The model class to use when loading the meta-policy.
+      model_path: The checkpoint path to load the meta-policy from.
+      top_percentage: Fraction of top performing perturbations to use for
+        gradient estimation.
+      orthogonal_suggestions: Whether to orthogonalize the perturbations.
+      quasirandom_suggestions: Whether quasirandom perturbations should be
+        used; valid only if orthogonal_suggestions = True.
+      top_sort_type: How to sort evaluation results for selecting top
+        directions. Valid options are: "max" and "diff".
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for
+        online mean std observation normalizer.
+      seed: The seed to use.
+      reward_buffer_size: The size of the reward buffer that stores a history
+        of rewards.
+      **kwargs: Other keyword arguments for base class.
+ """ + super().__init__(**kwargs) + super().__init__(**kwargs) + self._iteration = 0 + self._seed = seed + self._model_path = model_path + self._model = model() + self._last_std_used = 1.0 + self._num_top = int(top_percentage * self._num_suggestions) + self._num_top = max(1, self._num_top) + self._orthogonal_suggestions = orthogonal_suggestions + self._quasirandom_suggestions = quasirandom_suggestions + self._top_sort_type = top_sort_type + self._obs_norm_data_buffer = obs_norm_data_buffer + self._tree_weights = None + self._reward_buffer_size = reward_buffer_size + self._reward_buffer = collections.deque(maxlen=self._reward_buffer_size) + self._populate_reward_buffer() + self._step_size = 0.02 + self._std = 1.0 + + def _populate_reward_buffer(self): + """Populate reward buffer with very negative values.""" + self._reward_buffer.extend([_DUMMY_REWARD] * self._reward_buffer_size) + + def _restore_state_from_checkpoint(self, logdir: str): + try: + state = checkpoint_util.load_checkpoint_state(logdir) + iteration = 0 # No iteration information is extracted + return state, iteration + except ValueError: + logging.warning( + "Failed to load directly as a checkpoint, try searching subfolders" + " with checkpoints." + ) + return None, 0 + + def get_param_suggestions( + self, evaluate: bool = False + ) -> Sequence[Dict[str, Any]]: + """Suggests a list of inputs to evaluate the Blackbox function on. + + Suggestions are sampled from a gaussian distribution around the current + parameter vector. For each suggestion, a dict containing keyword arguments + for the worker is sent. + + Args: + evaluate: Whether to evaluate current optimization variables for reporting + training progress. + + Returns: + A list of suggested inputs for the workers to evaluate. + """ + if evaluate: + param_suggestions = [self._opt_params] * self._num_evals + else: + dimensions = self._opt_params.shape[0] + if self._orthogonal_suggestions: + if self._quasirandom_suggestions: + param_suggestions = ( + stateless_perturbation_generators.RandomHadamardMatrixGenerator( + self._num_suggestions, dimensions + ).generate_matrix() + ) + else: + # We generate random iid perturbations and orthogonalize them. In the + # case when the number of suggestions to be generated is greater than + # param dimensionality, we generate multiple orthogonal perturbation + # blocks. Rows are othogonal within a block but not across blocks. 
+          ortho_pert_blocks = []
+          for _ in range(math.ceil(float(self._num_suggestions / dimensions))):
+            perturbations = self._np_random_state.normal(
+                0, 1, (self._num_suggestions, dimensions)
+            )
+            ortho_matrix, _ = np.linalg.qr(perturbations.T)
+            ortho_pert_blocks.append(np.sqrt(dimensions) * ortho_matrix.T)
+          param_suggestions = np.vstack(ortho_pert_blocks)
+          param_suggestions = param_suggestions[: self._num_suggestions, :]
+      else:
+        param_suggestions = self._np_random_state.normal(
+            0, 1, (self._num_suggestions, dimensions)
+        )
+      self._last_std_used = self._std
+      param_suggestions = np.vstack([
+          self._opt_params,
+          self._opt_params + self._last_std_used * param_suggestions,
+          self._opt_params - self._last_std_used * param_suggestions,
+      ])
+
+    suggestions = []
+    for params in param_suggestions:
+      suggestion = {"params_to_eval": params}
+      if self._obs_norm_data_buffer is not None:
+        suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state
+        suggestion["update_obs_norm_buffer"] = not evaluate
+      suggestions.append(suggestion)
+    return suggestions
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]
+  ) -> None:
+
+    self._reward_buffer.append(eval_results[0].value)
+    rewards = np.asarray(self._reward_buffer)
+    model_input = np.concatenate([[self._iteration], rewards])
+
+    if self._tree_weights is None:
+      # _restore_state_from_checkpoint returns a (state, iteration) tuple;
+      # only the state is needed here.
+      self._state, _ = self._restore_state_from_checkpoint(self._model_path)
+      self._tree_weights = self._model.init(
+          jax.random.PRNGKey(seed=self._seed), model_input, self._state
+      )
+
+    hyper_params, self._state = self._model.apply(
+        self._tree_weights, model_input, self._state
+    )
+    step_size, std = hyper_params
+    self._step_size = step_size
+    self._std = std
+    super().process_evaluations(eval_results)
+
+
+class MultiAgentAugmentedRandomSearch(AugmentedRandomSearch):
+  """Augmented random search algorithm for multi-agent blackbox optimization."""
+
+  def __init__(self,
+               std: float,
+               step_size: float,
+               top_percentage: float = 1.0,
+               orthogonal_suggestions: bool = False,
+               quasirandom_suggestions: bool = False,
+               top_sort_type: str = "max",
+               obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None,
+               agent_keys: Optional[List[str]] = None,
+               restore_state_from_single_agent: bool = False,
+               **kwargs) -> None:
+    """Initializes the augmented random search algorithm for multi-agent training.
+
+    Args:
+      std: Standard deviation for normal perturbations around current
+        optimization parameter vector.
+      step_size: Step size for gradient ascent.
+      top_percentage: Fraction of top performing perturbations to use for
+        gradient estimation.
+      orthogonal_suggestions: Whether to orthogonalize the perturbations.
+      quasirandom_suggestions: Whether quasirandom perturbations should be
+        used; valid only if orthogonal_suggestions = True.
+      top_sort_type: How to sort evaluation results for selecting top
+        directions. Valid options are: "max" and "diff".
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for
+        online mean std observation normalizer.
+      agent_keys: List of keys which uniquely identify the agents. The
+        ordering needs to be consistent across the algorithm, policy, and
+        worker.
+      restore_state_from_single_agent: If True, then when
+        restore_state_from_checkpoint is called, the state is duplicated
+        self._num_agents times.
+      **kwargs: Other keyword arguments for base class.
+ """ + super().__init__(std=std, + step_size=step_size, + top_percentage=top_percentage, + orthogonal_suggestions=orthogonal_suggestions, + quasirandom_suggestions=quasirandom_suggestions, + top_sort_type=top_sort_type, + obs_norm_data_buffer=obs_norm_data_buffer, + **kwargs) + if agent_keys is None: + self._agent_keys = ["arm", "opp"] + else: + self._agent_keys = agent_keys + self._num_agents = len(self._agent_keys) + self._restore_state_from_single_agent = restore_state_from_single_agent + + def _split_params(self, params: np.ndarray) -> List[np.ndarray]: + return np.array_split(params, self._num_agents) + + def _combine_params(self, params_per_agents: List[np.ndarray]) -> np.ndarray: + return np.concatenate(params_per_agents, axis=0) + + def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: + logging.info("Restore: restore from 1 agent: %d", + self._restore_state_from_single_agent) + logging.info("Restore: num_agents: %d", self._num_agents) + logging.info("Restore: new state keys: %s", list(new_state.keys())) + logging.info("Restore: new_state params shape: %s", + new_state["params_to_eval"].shape) + + # Initialize multiple agents from a single agent. + if self._restore_state_from_single_agent: + if new_state["params_to_eval"].ndim != 1: + raise ValueError( + f"Params to eval has {new_state['params_to_eval'].ndim} dims, " + "should only have 1." + ) + duplicated_state = { + "params_to_eval": + np.tile(new_state["params_to_eval"], self._num_agents) + } + if self._obs_norm_data_buffer is not None: + duplicated_state["obs_norm_state"] = {} + duplicated_state["obs_norm_state"]["mean"] = np.tile( + new_state["obs_norm_state"]["mean"], self._num_agents) + duplicated_state["obs_norm_state"]["std"] = np.tile( + new_state["obs_norm_state"]["std"], self._num_agents) + duplicated_state["obs_norm_state"]["n"] = ( + new_state["obs_norm_state"]["n"]) + + self.state = duplicated_state + logging.info("Restore: duplicated states params shape: %s", + duplicated_state["params_to_eval"].shape) + + # Initialize one agent from a single agent. 
+    else:
+      self.state = new_state
+
+    logging.info("Restored state: params shape: %s, opt params shape: %s, "
+                 "obs norm state: %s",
+                 self.state["params_to_eval"].shape,
+                 self._opt_params.shape,
+                 self.state.get("obs_norm_state", None))
+    if self._obs_norm_data_buffer is not None:
+      logging.info("Restored state: obs norm mean shape: %s, std shape: %s",
+                   self.state["obs_norm_state"]["mean"].shape,
+                   self.state["obs_norm_state"]["std"].shape)
+
+  def maybe_save_custom_checkpoint(self,
+                                   state: Dict[str, Any],
+                                   checkpoint_path: Union[pathlib.Path, str]
+                                   ) -> None:
+    """Saves a checkpoint per agent with prefix checkpoint_path."""
+    agent_params = self._split_params(state["params_to_eval"])
+    for i in range(self._num_agents):
+      per_agent_state = {}
+      per_agent_state["params_to_eval"] = agent_params[i]
+      if self._obs_norm_data_buffer is not None:
+        obs_norm_state = state["obs_norm_state"]
+        elems_per_agent = int(
+            obs_norm_state["mean"].shape[-1] / self._num_agents)
+        per_agent_state["obs_norm_state"] = {}
+        start_idx = i * elems_per_agent
+        end_idx = (i + 1) * elems_per_agent
+        if obs_norm_state["mean"].ndim == 1:
+          per_agent_state["obs_norm_state"]["mean"] = (
+              obs_norm_state["mean"][start_idx:end_idx])
+          per_agent_state["obs_norm_state"]["std"] = (
+              obs_norm_state["std"][start_idx:end_idx])
+        else:
+          per_agent_state["obs_norm_state"]["mean"] = (
+              obs_norm_state["mean"][:, start_idx:end_idx])
+          per_agent_state["obs_norm_state"]["std"] = (
+              obs_norm_state["std"][:, start_idx:end_idx])
+        per_agent_state["obs_norm_state"]["n"] = obs_norm_state["n"]
+      agent_checkpoint_path = f"{checkpoint_path}_agent_{i}"
+      logging.info("Saving agent checkpoints to %s...", agent_checkpoint_path)
+      self.save_checkpoint_oss(agent_checkpoint_path, per_agent_state)
+
+  def save_checkpoint_oss(self, checkpoint_path: str, state: Any) -> None:
+    with open(checkpoint_path, "wb") as f:
+      pkl.dump(state, f)
+
+  def split_and_save_checkpoint(self, checkpoint_path: str) -> None:
+    state = checkpoint_util.load_checkpoint_state(checkpoint_path)
+    self.maybe_save_custom_checkpoint(state=state,
+                                      checkpoint_path=checkpoint_path)
+
+  def _get_top_evaluation_results(
+      self,
+      agent_key: str,
+      pos_eval_results: Sequence[worker_util.EvaluationResult],
+      neg_eval_results: Sequence[worker_util.EvaluationResult]
+  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    pos_evals = np.array(
+        [r.metrics[f"reward_{agent_key}"] for r in pos_eval_results])
+    neg_evals = np.array(
+        [r.metrics[f"reward_{agent_key}"] for r in neg_eval_results])
+    if self._top_sort_type == "max":
+      max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0)
+    elif self._top_sort_type == "diff":
+      max_evals = np.abs(pos_evals - neg_evals)
+    else:
+      raise ValueError(f"Unsupported top_sort_type: {self._top_sort_type}")
+    idx = (-max_evals).argsort()[:self._num_top]
+    pos_evals = pos_evals[idx]
+    neg_evals = neg_evals[idx]
+    return pos_evals, neg_evals, idx
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]) -> None:
+    """Processes the list of Blackbox function evaluations returned from workers.
+
+    The gradient is computed as a weighted sum of the perturbation directions,
+    weighted by the difference of their values from the current value. The
+    current parameter vector is then updated in the gradient direction with
+    the specified step size.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent. ARS performs antithetic
+        gradient estimation.
+        The suggestions are sent for evaluation in pairs. The eval_results
+        list should contain an even number of entries with the first half
+        entries corresponding to evaluation results of positive perturbations
+        and the last half corresponding to negative perturbations.
+    """
+
+    # Retrieve delta directions from the param suggestions sent for evaluation.
+    pos_eval_results = eval_results[:self._num_suggestions]
+    neg_eval_results = eval_results[self._num_suggestions:]
+    filtered_pos_eval_results = []
+    filtered_neg_eval_results = []
+    for i in range(len(pos_eval_results)):
+      if (pos_eval_results[i].params_evaluated.size) and (
+          neg_eval_results[i].params_evaluated.size):
+        filtered_pos_eval_results.append(pos_eval_results[i])
+        filtered_neg_eval_results.append(neg_eval_results[i])
+
+    params = np.array([r.params_evaluated for r in filtered_pos_eval_results])
+    eval_results = filtered_pos_eval_results + filtered_neg_eval_results
+
+    # This list has length num_pos_results, with per-agent splits.
+    eval_params_per_agent = [self._split_params(p) for p in params]
+    eval_params_per_agent = list(zip(*eval_params_per_agent))
+    # This has length num_agents, where each entry is a 2d array of shape
+    # (num_pos_results, agent_params_dim).
+    eval_params_per_agent = [np.array(a) for a in eval_params_per_agent]
+
+    current_params_per_agent = self._split_params(self._opt_params)
+    updated_params_per_agent = []
+    for (agent_eval_params, agent_params, agent_key) in zip(
+        eval_params_per_agent, current_params_per_agent, self._agent_keys):
+      pos_evals, neg_evals, idx = self._get_top_evaluation_results(
+          agent_key=agent_key,
+          pos_eval_results=filtered_pos_eval_results,
+          neg_eval_results=filtered_neg_eval_results)
+      all_top_evals = np.hstack([pos_evals, neg_evals])
+      evals = pos_evals - neg_evals
+
+      # Get delta directions corresponding to top evals
+      directions = (agent_eval_params - agent_params) / self._std
+      directions = directions[idx, :]
+
+      # Estimate gradients
+      gradient = np.dot(evals, directions) / evals.shape[0]
+      if not np.isclose(np.std(all_top_evals), 0.0):
+        gradient /= np.std(all_top_evals)
+
+      # Apply gradients
+      updated_agent_params = agent_params + self._step_size * gradient
+      updated_params_per_agent.append(updated_agent_params)
+
+    self._opt_params = self._combine_params(updated_params_per_agent)
+
+    # Update the observation buffer
+    if self._obs_norm_data_buffer is not None:
+      for r in eval_results:
+        self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data)
diff --git a/iris/algorithms/ars_algorithm_test.py b/iris/algorithms/ars_algorithm_test.py
new file mode 100644
index 0000000..349551d
--- /dev/null
+++ b/iris/algorithms/ars_algorithm_test.py
@@ -0,0 +1,142 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
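
For orientation, the antithetic ARS update exercised by the tests below can be sketched in a few lines of standalone numpy. This is an illustrative sketch only, not part of the patch; `ars_step` and `f` are hypothetical names.

import numpy as np


def ars_step(theta, f, num_suggestions=3, std=1.0, step_size=0.5, seed=0):
  """One antithetic ARS update on a maximization objective f (a sketch)."""
  rng = np.random.RandomState(seed)
  deltas = rng.normal(0, 1, (num_suggestions, theta.shape[0]))
  pos = np.array([f(theta + std * d) for d in deltas])  # positive perturbations
  neg = np.array([f(theta - std * d) for d in deltas])  # negative perturbations
  gradient = np.dot(pos - neg, deltas) / num_suggestions
  all_evals = np.hstack([pos, neg])
  if not np.isclose(np.std(all_evals), 0.0):
    gradient /= np.std(all_evals)  # reward-std scaling, as in ARS
  return theta + step_size * gradient


# Example: one step on f(x) = -||x||^2 moves theta toward the origin.
theta = ars_step(np.array([10.0, 10.0]), lambda x: -np.dot(x, x))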
+
+from iris import normalizer
+from iris import worker_util
+from iris.algorithms import ars_algorithm
+import numpy as np
+import tensorflow as tf
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class AlgorithmTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      (True, False),
+      (False, False),
+  )
+  def test_ars_gradient(self, orthogonal_suggestions, quasirandom_suggestions):
+    algo = ars_algorithm.AugmentedRandomSearch(
+        num_suggestions=3,
+        step_size=0.5,
+        std=1.,
+        top_percentage=1,
+        orthogonal_suggestions=orthogonal_suggestions,
+        quasirandom_suggestions=quasirandom_suggestions,
+        random_seed=7)
+    init_state = {'init_params': np.array([10., 10.])}
+    algo.initialize(init_state)
+    suggestions = algo.get_param_suggestions()
+    self.assertLen(suggestions, 6)
+    eval_results = [
+        worker_util.EvaluationResult(np.array([10., 11.]), 10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.empty(0), 0),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 11.]), 10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 9.]), -10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 9.]), -10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.empty(0), 0),  # pytype: disable=wrong-arg-types  # numpy-scalars
+    ]
+    algo.process_evaluations(eval_results)
+    np.testing.assert_array_equal(algo._opt_params, np.array([10, 11]))
+
+  def test_ars_gradient_with_schedule(self):
+    algo = ars_algorithm.AugmentedRandomSearch(
+        num_suggestions=3,
+        step_size=lambda x: x + 0.5,
+        std=lambda x: x + 1.,
+        top_percentage=1,
+        random_seed=7)
+    init_state = {'init_params': np.array([10., 10.])}
+    algo.initialize(init_state)
+    suggestions = algo.get_param_suggestions()
+    self.assertLen(suggestions, 6)
+    eval_results = [
+        worker_util.EvaluationResult(np.array([10., 11.]), 10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.empty(0), 0),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 11.]), 10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 9.]), -10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.array([10., 9.]), -10),  # pytype: disable=wrong-arg-types  # numpy-scalars
+        worker_util.EvaluationResult(np.empty(0), 0),  # pytype: disable=wrong-arg-types  # numpy-scalars
+    ]
+    algo.process_evaluations(eval_results)
+    np.testing.assert_array_equal(algo._opt_params, np.array([10, 11]))
+
+  @parameterized.parameters(
+      ({'mean': np.asarray([1., 2.]), 'std': np.asarray([3., 4.]), 'n': 5},),
+      (None,),
+  )
+  def test_restore_state_from_checkpoint(self, expected_obs_norm_state):
+    algo = ars_algorithm.AugmentedRandomSearch(
+        num_suggestions=3,
+        step_size=0.5,
+        std=1.0,
+        top_percentage=1,
+        orthogonal_suggestions=True,
+        quasirandom_suggestions=True,
+        obs_norm_data_buffer=normalizer.MeanStdBuffer()
+        if expected_obs_norm_state is not None
+        else None,
+        random_seed=7,
+    )
+    init_state = {'init_params': np.array([10., 10.])}
+    if expected_obs_norm_state:
+      init_state['obs_norm_buffer_data'] = {'mean': np.asarray([0., 0.]),
+                                            'std': np.asarray([1., 1.]),
+                                            'n': 0}
+    algo.initialize(init_state)
+
+    with self.subTest('init-mean'):
+      self.assertAllClose(np.array(algo._opt_params),
+                          init_state['init_params'])
+    if expected_obs_norm_state is not None:
+      with self.subTest('init-obs-mean'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['mean']),
+            np.asarray(init_state['obs_norm_buffer_data']['mean']))
+      with self.subTest('init-obs-n'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['n']),
+            np.asarray(init_state['obs_norm_buffer_data']['n']))
+      with self.subTest('init-obs-std'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['std']),
+            init_state['obs_norm_buffer_data']['std'])
+
+    expected_restore_state = {'params_to_eval': np.array([5., 6.])}
+    if expected_obs_norm_state is not None:
+      expected_restore_state['obs_norm_state'] = expected_obs_norm_state
+    algo.restore_state_from_checkpoint(expected_restore_state)
+
+    self.assertAllClose(algo._opt_params,
+                        expected_restore_state['params_to_eval'])
+    if expected_obs_norm_state is not None:
+      std = expected_restore_state['obs_norm_state']['std']
+      var = np.square(std)
+      expected_unnorm_var = var * 4
+      with self.subTest('restore-obs-mean'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['mean']),
+            np.asarray(expected_restore_state['obs_norm_state']['mean']))
+      with self.subTest('restore-obs-n'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['n']),
+            np.asarray(expected_restore_state['obs_norm_state']['n']))
+      with self.subTest('restore-obs-std'):
+        self.assertAllClose(
+            np.asarray(algo._obs_norm_data_buffer.data['unnorm_var']),
+            expected_unnorm_var)
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/iris/algorithms/cma_algorithm.py b/iris/algorithms/cma_algorithm.py
new file mode 100644
index 0000000..80f656b
--- /dev/null
+++ b/iris/algorithms/cma_algorithm.py
@@ -0,0 +1,152 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Algorithm class for the Covariance Matrix Adaptation Evolution Strategy (CMA-ES) Blackbox algorithm."""
+
+from typing import Any, Dict, Optional, Sequence
+
+import cma
+from iris import normalizer
+from iris import worker_util
+from iris.algorithms import algorithm
+import numpy as np
+
+
+class CMAES(algorithm.BlackboxAlgorithm):
+  """CMA-ES for blackbox optimization.
+
+  CMA-ES is a blackbox optimization algorithm that alternates between
+  sampling new candidate solutions from a multivariate Gaussian and
+  updating the covariance matrix of that Gaussian based on the evaluation
+  history. More details regarding CMA-ES can be found here:
+  https://arxiv.org/abs/1604.00772. In this code, we use the pycma package
+  to implement the algorithm.
+  """
+
+  def __init__(self,
+               std: float = 0.3,
+               bounds: Sequence[float] = (-1, 1),
+               obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None,
+               **kwargs) -> None:
+    """Initializes the CMA-ES algorithm.
+
+    Args:
+      std: Initial standard deviation to be used in CMA-ES.
+      bounds: Bounds of the search parameters.
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for
+        online mean std observation normalizer.
+      **kwargs: Other keyword arguments for base class.
+    """
+    super().__init__(**kwargs)
+    self._std = std
+    self._bounds = bounds
+    self._cmaes = cma.CMAEvolutionStrategy(np.empty(5), self._std,
+                                           {
+                                               "popsize": self._num_suggestions,
+                                               "bounds": list(self._bounds)
+                                           })
+    self._obs_norm_data_buffer = obs_norm_data_buffer
+
+  def initialize(self, state: Dict[str, Any]) -> None:
+    """Initializes the algorithm from initial worker state."""
+    self._opt_params = state["init_params"]
+
+    self._cmaes = cma.CMAEvolutionStrategy(self._opt_params, self._std,
+                                           {
+                                               "popsize": self._num_suggestions,
+                                               "bounds": list(self._bounds)
+                                           })
+    # Initialize observation normalization buffer with init data from the
+    # worker.
+    if self._obs_norm_data_buffer is not None:
+      self._obs_norm_data_buffer.data = state["obs_norm_buffer_data"]
+    self._best_value = None
+
+  def process_evaluations(self,
+                          eval_results: Sequence[worker_util.EvaluationResult]
+                          ) -> None:
+    """Processes the list of Blackbox function evaluations returned from workers.
+
+    The best parameters seen so far are tracked, and a full population of
+    evaluations is fed back to pycma to update the mean and covariance of the
+    sampling distribution.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent.
+    """
+
+    filtered_eval_results = [e for e in eval_results if e.params_evaluated.size]
+    all_params = np.array([r.params_evaluated for r in filtered_eval_results])
+    all_values = np.array([r.value for r in filtered_eval_results])
+
+    if filtered_eval_results:
+      if self._best_value is None or np.max(all_values) > self._best_value:
+        self._best_value = np.max(all_values)
+        self._opt_params = np.copy(all_params[np.argmax(all_values)])
+
+    if len(all_params) == len(all_values) == self._num_suggestions:
+      # pycma minimizes, so negate the values to maximize reward.
+      self._cmaes.tell(all_params, -all_values)
+
+    # Update the observation buffer
+    if self._obs_norm_data_buffer is not None:
+      for r in filtered_eval_results:
+        self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data)
+
+  def get_param_suggestions(self,
+                            evaluate: bool = False
+                            ) -> Sequence[Dict[str, Any]]:
+    """Suggests a list of inputs to evaluate the Blackbox function on.
+
+    Suggestions are sampled from the CMA-ES multivariate Gaussian around the
+    current parameter vector. For each suggestion, a dict containing keyword
+    arguments for the worker is sent.
+
+    Args:
+      evaluate: Whether to evaluate current optimization variables
+        for reporting training progress.
+
+    Returns:
+      A list of suggested inputs for the workers to evaluate.
+ """ + if evaluate: + param_suggestions = [self._opt_params] * self._num_evals + else: + param_suggestions = self._cmaes.ask() + suggestions = [] + for params in param_suggestions: + suggestion = {"params_to_eval": params} + if self._obs_norm_data_buffer is not None: + suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state + suggestion["update_obs_norm_buffer"] = not evaluate + suggestions.append(suggestion) + return suggestions + + @property + def state(self) -> Dict[str, Any]: + return self._get_state() + + def _get_state(self) -> Dict[str, Any]: + state = {"params_to_eval": self._opt_params} + if self._obs_norm_data_buffer is not None: + state["obs_norm_state"] = self._obs_norm_data_buffer.state + return state + + @state.setter + def state(self, new_state: Dict[str, Any]) -> None: + self._set_state(new_state) + + def _set_state(self, new_state: Dict[str, Any]) -> None: + self._opt_params = new_state["params_to_eval"] + if self._obs_norm_data_buffer is not None: + self._obs_norm_data_buffer.state = new_state["obs_norm_state"] diff --git a/iris/algorithms/cma_algorithm_test.py b/iris/algorithms/cma_algorithm_test.py new file mode 100644 index 0000000..d0ebb44 --- /dev/null +++ b/iris/algorithms/cma_algorithm_test.py @@ -0,0 +1,63 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from iris import worker_util
+from iris.algorithms import cma_algorithm
+import numpy as np
+from absl.testing import absltest
+
+_TRUE_OPTIMAL = (-1, -1)
+
+
+def test_fn(x):
+  """A simple quadratic function to be maximized."""
+  return -np.linalg.norm(x - np.array(_TRUE_OPTIMAL))
+
+
+class CMAAlgorithmTest(absltest.TestCase):
+
+  def setUp(self):
+    super(CMAAlgorithmTest, self).setUp()
+    self.algo = cma_algorithm.CMAES(
+        num_suggestions=10,
+        std=0.5,
+        bounds=(-3, 3),
+        random_seed=7)
+    init_state = {'init_params': np.array([0., 0.])}
+    self.algo.initialize(init_state)
+
+  def test_get_param_suggestions(self):
+    eval_suggestion_list = self.algo.get_param_suggestions(evaluate=True)
+    for eval_suggestion in eval_suggestion_list:
+      np.testing.assert_almost_equal(eval_suggestion['params_to_eval'],
+                                     np.array([0., 0.]))
+
+  def test_cma_optimization(self):
+    for i in range(100):
+      suggestion_list = self.algo.get_param_suggestions(evaluate=False)
+
+      eval_results = []
+      for suggestion in suggestion_list:
+        eval_results.append(
+            worker_util.EvaluationResult(
+                np.array(suggestion['params_to_eval']),
+                test_fn(np.array(suggestion['params_to_eval']))))
+      if i % 10 == 0:
+        eval_results[0] = worker_util.EvaluationResult(np.empty(0), 0)  # pytype: disable=wrong-arg-types  # numpy-scalars
+      self.algo.process_evaluations(eval_results)
+    np.testing.assert_almost_equal(self.algo._opt_params, _TRUE_OPTIMAL)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/iris/algorithms/controllers/__init__.py b/iris/algorithms/controllers/__init__.py
new file mode 100644
index 0000000..5e5a6e1
--- /dev/null
+++ b/iris/algorithms/controllers/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads all controllers."""
+from iris.algorithms.controllers import hill_climb_controller
+from iris.algorithms.controllers import neat_controller
+from iris.algorithms.controllers import policy_gradient_controller
+from iris.algorithms.controllers import random_controller
+from iris.algorithms.controllers import regularized_evolution_controller
+
+CONTROLLER_DICT = {
+    "hill_climb":
+        hill_climb_controller.HillClimbController,
+    "neat":
+        neat_controller.NEATController,
+    "policy_gradient":
+        policy_gradient_controller.PolicyGradientController,
+    "random_search":
+        random_controller.RandomController,
+    "regularized_evolution":
+        regularized_evolution_controller.RegularizedEvolutionController
+}
diff --git a/iris/algorithms/controllers/base_controller.py b/iris/algorithms/controllers/base_controller.py
new file mode 100644
index 0000000..af4662f
--- /dev/null
+++ b/iris/algorithms/controllers/base_controller.py
@@ -0,0 +1,82 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base class for all controllers in ES-ENAS."""
+
+import abc
+from typing import List, Optional
+
+import pyglove as pg
+
+
+class BaseController(abc.ABC):
+  """Base class for all controllers in ES-ENAS."""
+
+  def __init__(self, dna_spec: pg.DNASpec,
+               batch_size: int) -> None:
+    """Initialization.
+
+    Args:
+      dna_spec: A search space definition for the controller to use.
+      batch_size: Number of suggestions in the current iteration.
+    """
+    self._dna_spec = dna_spec
+    self._batch_size = batch_size
+    self._controller = pg.DNAGenerator()
+    self._history = []
+
+  def propose_dna(self) -> pg.DNA:
+    """Proposes a topology dna using stored template.
+
+    Returns:
+      dna: A proposed dna.
+    """
+    return self._controller.propose()
+
+  def collect_rewards_and_train(self, reward_vector: List[float],
+                                dna_list: List[pg.DNA]):
+    """Collects rewards to update the controller.
+
+    Args:
+      reward_vector: List of reward floats.
+      dna_list: List of dna's from the proposal function.
+    """
+
+    for i, dna in enumerate(dna_list):
+      dna.reward = reward_vector[i]
+      self._controller.feedback(dna, dna.reward)
+      self._history.append((dna, dna.reward))
+
+  @abc.abstractmethod
+  def get_state(self) -> Optional[str]:
+    """Returns serialized version of controller algorithm state.
+
+    Serialization is required for compatibility with iris states.
+
+    Returns:
+      Serialized state in string format.
+    """
+
+  @abc.abstractmethod
+  def set_state(self, serialized_state: Optional[str] = None) -> None:
+    """Sets the controller algorithm state from a serialized state.
+
+    Args:
+      serialized_state: State, serialized in string format.
+    """
diff --git a/iris/algorithms/controllers/controllers_test.py b/iris/algorithms/controllers/controllers_test.py
new file mode 100644
index 0000000..f20e7e1
--- /dev/null
+++ b/iris/algorithms/controllers/controllers_test.py
@@ -0,0 +1,54 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
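
For readers unfamiliar with PyGlove's generator interface, the propose/feedback loop that BaseController wraps looks roughly like this. This is an illustrative sketch using only API calls that appear in this patch; it is not part of the patch itself, and the random reward stands in for a real blackbox evaluation.

import random

import pyglove as pg

# Search space over a single categorical choice.
dna_spec = pg.dna_spec(pg.one_of(['a', 'b', 'c']))

generator = pg.evolution.RegularizedEvolution(
    population_size=4,
    tournament_size=2,
    mutator=pg.evolution.mutators.Uniform(),
    seed=0)
generator.setup(dna_spec)

for _ in range(8):
  dna = generator.propose()          # sample a candidate architecture
  reward = random.random()           # stand-in for a real evaluation
  generator.feedback(dna, reward)    # update the controller's population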
+ +"""Tests for controllers.""" + +import random +from iris.algorithms.controllers import regularized_evolution_controller +import pyglove as pg +from absl.testing import absltest + + +class ControllersTest(absltest.TestCase): + + def test_checkpointing(self): + example_dna_spec = pg.dna_spec(pg.one_of(['a', 'b', 'c'])) + batch_size = 4 + seed = 0 + controller = regularized_evolution_controller.RegularizedEvolutionController( + example_dna_spec, batch_size=batch_size, seed=seed) + + for _ in range(20): + dna = controller.propose_dna() + controller.collect_rewards_and_train([random.random()], [dna]) + + another_controller = regularized_evolution_controller.RegularizedEvolutionController( + example_dna_spec, batch_size=batch_size, seed=seed) + another_controller.set_state(controller.get_state()) + + feedback_in_controller = [ + (dna, pg.evolution.base.get_fitness(dna)) + for dna in list(controller._controller._population) + ] + + feedback_in_another_controller = [ + (dna, pg.evolution.base.get_fitness(dna)) + for dna in list(another_controller._controller._population) + ] + self.assertEqual(feedback_in_controller, feedback_in_another_controller) + self.assertEqual(controller.propose_dna(), another_controller.propose_dna()) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/controllers/hill_climb_controller.py b/iris/algorithms/controllers/hill_climb_controller.py new file mode 100644 index 0000000..50ec4bc --- /dev/null +++ b/iris/algorithms/controllers/hill_climb_controller.py @@ -0,0 +1,43 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""(Batch) Hill Climb Controller from PyGlove.""" +from typing import Optional +from iris.algorithms.controllers import base_controller +import pyglove as pg + + +class HillClimbController(base_controller.BaseController): + """Batch HillClimb Controller.""" + + def __init__(self, + dna_spec: pg.DNASpec, + batch_size: int, + seed: Optional[int] = None, + **kwargs): + """Initialization. See base class for more details.""" + + super().__init__(dna_spec, batch_size) + self._controller = pg.evolution.hill_climb( + pg.evolution.mutators.Uniform(), + batch_size=batch_size, init_population_size=1, seed=seed) # pytype: disable=wrong-arg-types # gen-stub-imports + self._controller.setup(self._dna_spec) + + def get_state(self): + # TODO: Add checkpointing logic for HillClimb. + return None + + def set_state(self, serialized_state): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + # TODO: See above. + pass diff --git a/iris/algorithms/controllers/neat_controller.py b/iris/algorithms/controllers/neat_controller.py new file mode 100644 index 0000000..5e50dd2 --- /dev/null +++ b/iris/algorithms/controllers/neat_controller.py @@ -0,0 +1,45 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""NEAT Controller from PyGlove."""
+from typing import Optional
+from iris.algorithms.controllers import base_controller
+import pyglove as pg
+
+
+class NEATController(base_controller.BaseController):
+  """NEAT Controller."""
+
+  def __init__(self,
+               dna_spec: pg.DNASpec,
+               batch_size: int,
+               seed: Optional[int] = None,
+               **kwargs):
+    """Initialization. See base class for more details."""
+
+    super().__init__(dna_spec, batch_size)
+    population_size = self._batch_size
+    self._controller = pg.evolution.neat(
+        population_size=population_size,
+        mutator=pg.evolution.mutators.Uniform(),
+        seed=seed)  # pytype: disable=wrong-arg-types  # gen-stub-imports
+    self._controller.setup(self._dna_spec)
+
+  def get_state(self):
+    # TODO: Add checkpointing logic for NEAT.
+    return None
+
+  def set_state(self, serialized_state):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
+    # TODO: See above.
+    pass
diff --git a/iris/algorithms/controllers/policy_gradient_controller.py b/iris/algorithms/controllers/policy_gradient_controller.py
new file mode 100644
index 0000000..b101e29
--- /dev/null
+++ b/iris/algorithms/controllers/policy_gradient_controller.py
@@ -0,0 +1,51 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This is an updated variant of the original policy gradient-based MetaArchitect controller."""
+from iris.algorithms.controllers import base_controller
+import pyglove as pg
+
+
+class PolicyGradientController(base_controller.BaseController):
+  """Policy Gradient Controller."""
+
+  def __init__(self,
+               dna_spec: pg.DNASpec,
+               batch_size: int,
+               update_batch_size=64,
+               **kwargs):
+    """Initialization. See base class for more details."""
+
+    super().__init__(dna_spec, batch_size)
+    self._controller = pg.reinforcement_learning.PPO(  # pytype: disable=module-attr
+        train_batch_size=self._batch_size, update_batch_size=update_batch_size)
+    self._controller.setup(self._dna_spec)
+    # If you have:
+    #   training batch size N (PG proposes a batch N of models, stored in cache)
+    #   update batch size M (minibatch update batch size)
+    #   num. of updates P (how many minibatch updates)
+    # the update rule is:
+    #
+    #   for _ in range(P):
+    #     mini_batch = select(M, N)
+    #     train(model, mini_batch)
+
+  def get_state(self):
+    # TODO: See pyglove policy_gradients generator for previous
+    # implementations. cl/325886417
+    return None
+
+  def set_state(self, serialized_state):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
+    # TODO: See above.
+    pass
diff --git a/iris/algorithms/controllers/random_controller.py b/iris/algorithms/controllers/random_controller.py
new file mode 100644
index 0000000..c81cec6
--- /dev/null
+++ b/iris/algorithms/controllers/random_controller.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Random Controller that proposes random topologies."""
+
+from iris.algorithms.controllers import base_controller
+import pyglove as pg
+
+
+class RandomController(base_controller.BaseController):
+  """Random Search Controller."""
+
+  def __init__(self, dna_spec: pg.DNASpec, batch_size: int,
+               **kwargs):
+    """Initialization. See base class for more details."""
+    super().__init__(dna_spec, batch_size)
+    del kwargs
+
+  def propose_dna(self):
+    """Proposes a topology dna using stored template.
+
+    Returns:
+      dna: A proposed dna.
+    """
+    return pg.random_dna(self._dna_spec)
+
+  def collect_rewards_and_train(self, reward_vector, dna_list):
+    """Ignores rewards, since random search is stateless.
+
+    Args:
+      reward_vector: List of reward floats.
+      dna_list: List of dna's from the proposal function.
+    """
+
+    del reward_vector
+    del dna_list
+
+  def get_state(self):
+    return None
+
+  def set_state(self, serialized_state):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
+    pass
diff --git a/iris/algorithms/controllers/regularized_evolution_controller.py b/iris/algorithms/controllers/regularized_evolution_controller.py
new file mode 100644
index 0000000..8b702e9
--- /dev/null
+++ b/iris/algorithms/controllers/regularized_evolution_controller.py
@@ -0,0 +1,51 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Regularized Evolution controller from PyGlove."""
+# pylint: disable=protected-access
+from typing import Optional
+from iris.algorithms.controllers import base_controller
+import numpy as np
+import pyglove as pg
+
+
+class RegularizedEvolutionController(base_controller.BaseController):
+  """Regularized Evolution Controller."""
+
+  def __init__(self,
+               dna_spec: pg.DNASpec,
+               batch_size: int,
+               seed: Optional[int] = None,
+               **kwargs):
+    """Initialization. See base class for more details."""
+
+    super().__init__(dna_spec, batch_size)
+    # Hyperparameters copied from example colab:
+    # http://pyglove/generators/evolution_example.ipynb
+    population_size = self._batch_size
+    tournament_size = int(np.sqrt(population_size))
+
+    self._controller = pg.evolution.RegularizedEvolution(
+        population_size=population_size,
+        tournament_size=tournament_size,
+        mutator=pg.evolution.mutators.Uniform(seed=seed),
+        seed=seed)  # pytype: disable=wrong-arg-types # gen-stub-imports
+    self._controller.setup(self._dna_spec)
+
+  def get_state(self):
+    return pg.to_json_str(self._history)  # pytype: disable=attribute-error
+
+  def set_state(self, serialized_state):  # pytype: disable=signature-mismatch # overriding-parameter-count-checks
+    self._history = pg.from_json_str(serialized_state)
+    self._controller.recover(self._history)
diff --git a/iris/algorithms/es_enas_algorithm.py b/iris/algorithms/es_enas_algorithm.py
new file mode 100644
index 0000000..8710d01
--- /dev/null
+++ b/iris/algorithms/es_enas_algorithm.py
@@ -0,0 +1,138 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Algorithm class for ES-ENAS algorithm."""
+
+import functools
+from multiprocessing import dummy as mp_threads
+from typing import Any, Dict, Sequence
+
+from iris import worker_util
+from iris.algorithms import ars_algorithm
+from iris.algorithms import controllers
+import pyglove as pg
+
+
+class ES_ENAS(ars_algorithm.AugmentedRandomSearch):  # pylint: disable=invalid-name
+  """ES-ENAS algorithm for NAS-related blackbox optimization.
+
+  Adds PyGlove as an additional optimizer for discrete/combinatorial search
+  spaces, making this a combination of two different algorithms (ARS and
+  PyGlove controllers).
+
+  At its core, the logic mainly appends an extra "dna" (model architecture)
+  to each AugmentedRandomSearch request. This "dna" is then processed by the
+  PyGlove controller.
+  """
+
+  def __init__(self,
+               controller_str: str = "regularized_evolution",
+               dna_proposal_interval: int = 50,
+               multithreading: bool = False,
+               **kwargs) -> None:
+    """Initializes the ES-ENAS algorithm, as well as the ARS parent class.
+
+    Args:
+      controller_str: Which controller algorithm to use on PyGlove side.
+      dna_proposal_interval: Iteration interval at which to propose new
+        architectures.
+      multithreading: Whether to multithread PyGlove DNA serialization. Pool
+        created after __init__ to avoid Launchpad pickling issues.
+      **kwargs: Arguments to parent classes (e.g.
AugmentedRandomSearch) + """ + super().__init__(**kwargs) + self._interval_counter = 0 + self._dna_proposal_interval = dna_proposal_interval + self._controller_fn = functools.partial( + controllers.CONTROLLER_DICT[controller_str], + batch_size=2 * self._num_suggestions) + + self._multithreading = multithreading + + def initialize(self, state: Dict[str, Any]) -> None: + super().initialize(state) + if self._multithreading: + self._pool = mp_threads.Pool(self._num_suggestions) + + self._dna_spec = pg.from_json_str(state["serialized_dna_spec"]) + self._controller = self._controller_fn(dna_spec=self._dna_spec) + self._interval_counter = 0 + self._evaluated_serialized_dnas = [] + self._evaluated_rewards = [] + + def process_evaluations( + self, eval_results: Sequence[worker_util.EvaluationResult]) -> None: + super().process_evaluations(eval_results) + + eval_metadatas = [] + eval_rewards = [] + for eval_result in eval_results: + if eval_result.metadata: + eval_metadatas.append(eval_result.metadata) + eval_rewards.append(eval_result.value) + + if eval_metadatas: + + def proper_unserialize(metadata: str) -> pg.DNA: + dna = pg.from_json_str(metadata) + # Put back the DNASpec into DNA, since serialization removed it. + dna.use_spec(self._dna_spec) + return dna + + if self._multithreading: + dna_list = self._pool.map(proper_unserialize, eval_metadatas) # pytype:disable=attribute-error + else: + dna_list = map(proper_unserialize, eval_metadatas) + dna_list = list(dna_list) + self._controller.collect_rewards_and_train(eval_rewards, dna_list) + + def get_param_suggestions(self, + evaluate: bool = False) -> Sequence[Dict[str, Any]]: + vanilla_suggestions = super().get_param_suggestions(evaluate) + # Evaluation never calls `process_evaluations`, but we need to update eval + # worker DNAs. + suggest_dna_bool = (self._interval_counter % + self._dna_proposal_interval) == 0 + + if suggest_dna_bool: + # Note that for faster serialization, DNASpec is removed from DNA. + dna_list = [self._controller.propose_dna() for _ in vanilla_suggestions] + if self._multithreading: + metadata_list = self._pool.map(pg.to_json_str, dna_list) # pytype:disable=attribute-error + else: + metadata_list = map(pg.to_json_str, dna_list) + metadata_list = list(metadata_list) + else: + metadata_list = [None] * len(vanilla_suggestions) + + for i, vanilla_suggestion in enumerate(vanilla_suggestions): + vanilla_suggestion["metadata"] = metadata_list[i] + + if not evaluate: + self._interval_counter += 1 + return vanilla_suggestions + + def _get_state(self) -> Dict[str, Any]: + vanilla_state = super()._get_state() + vanilla_state["interval_counter"] = self._interval_counter + vanilla_state["serialized_dna_spec"] = pg.to_json_str(self._dna_spec) + vanilla_state["controller_alg_state"] = self._controller.get_state() + return vanilla_state + + def _set_state(self, new_state: Dict[str, Any]) -> None: + super()._set_state(new_state) # pytype: disable=attribute-error + self._interval_counter = new_state["interval_counter"] + self._dna_spec = pg.from_json_str(new_state["serialized_dna_spec"]) + self._controller = self._controller_fn(dna_spec=self._dna_spec) + self._controller.set_state(new_state["controller_alg_state"]) diff --git a/iris/algorithms/es_enas_algorithm_test.py b/iris/algorithms/es_enas_algorithm_test.py new file mode 100644 index 0000000..73d76e6 --- /dev/null +++ b/iris/algorithms/es_enas_algorithm_test.py @@ -0,0 +1,111 @@ +# Copyright 2024 Google LLC. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: disable=attribute-error +from gym import spaces +from iris import worker_util +from iris.algorithms import es_enas_algorithm +from iris.policies import nas_policy +import numpy as np +import pyglove as pg +from absl.testing import absltest +from absl.testing import parameterized + + +def make_init_state(): + policy = nas_policy.NumpyEdgeSparsityPolicy( + ob_space=spaces.Box(low=-10, high=10, shape=(5,)), + ac_space=spaces.Box(low=-10, high=10, shape=(3,)), + hidden_layer_sizes=[16], + hidden_layer_edge_num=[3, 3]) + weights = policy.get_weights() + return { + 'init_params': weights, + 'serialized_dna_spec': pg.to_json_str(policy.dna_spec) + } + + +def make_evaluation_results(suggestion_list): + eval_results = [] + for suggestion in suggestion_list[:-1]: + evaluation_result = worker_util.EvaluationResult( + params_evaluated=suggestion['params_to_eval'], + value=np.random.uniform(), + metadata=suggestion['metadata']) + eval_results.append(evaluation_result) + eval_results.append(worker_util.EvaluationResult(np.empty(0), 0)) # pytype: disable=wrong-arg-types # numpy-scalars + return eval_results + + +class EsEnasAlgorithmTest(parameterized.TestCase): + + def setUp(self): + self.dna_proposal_interval = 1 + self.num_suggestions = 40 + self.step_size = 0.5 + self.std = 1.0 + self.random_seed = 7 + super().setUp() + + @parameterized.named_parameters( + ('hill_climb', 'hill_climb'), ('neat', 'neat'), + ('policy_gradient', 'policy_gradient'), + ('random_search', 'random_search'), + ('regularized_evolution', 'regularized_evolution')) + def test_es_enas_step(self, controller_str): + algo = es_enas_algorithm.ES_ENAS( + controller_str=controller_str, + dna_proposal_interval=self.dna_proposal_interval, + num_suggestions=self.num_suggestions, + step_size=self.step_size, + std=self.std, + random_seed=self.random_seed) + + init_state = make_init_state() + algo.initialize(init_state) + + suggestion_list = algo.get_param_suggestions(evaluate=False) + self.assertEqual(algo._interval_counter, 1) + + eval_results = make_evaluation_results(suggestion_list) + algo.process_evaluations(eval_results) + + # Verifies that algo can keep track of previous evaluations. 
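+    # The full state round-trips the ARS params, the interval counter, the
+    # serialized DNASpec and the controller state (see ES_ENAS._get_state),
+    # so restoring it should reproduce the controller exactly.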
+ current_full_state = algo.state + controller_state = algo._controller.get_state() + + algo.initialize(init_state) + algo.state = current_full_state + self.assertEqual(algo._controller.get_state(), controller_state) + + @parameterized.named_parameters(('False', False), ('True', True)) + def test_multithreading(self, multithreading): + algo = es_enas_algorithm.ES_ENAS( + dna_proposal_interval=self.dna_proposal_interval, + multithreading=multithreading, + num_suggestions=self.num_suggestions, + step_size=self.step_size, + std=self.std, + random_seed=self.random_seed) + + init_state = make_init_state() + algo.initialize(init_state) + + suggestion_list = algo.get_param_suggestions(evaluate=False) + eval_results = make_evaluation_results(suggestion_list) + algo.process_evaluations(eval_results) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/multi_agent_ars_algorithm_test.py b/iris/algorithms/multi_agent_ars_algorithm_test.py new file mode 100644 index 0000000..b9cebb4 --- /dev/null +++ b/iris/algorithms/multi_agent_ars_algorithm_test.py @@ -0,0 +1,500 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from iris import checkpoint_util +from iris import normalizer +from iris import worker_util +from iris.algorithms import ars_algorithm +import numpy as np +import tensorflow as tf +from absl.testing import absltest +from absl.testing import parameterized + + +class AlgorithmTest(tf.test.TestCase, parameterized.TestCase): + + def _init_algo(self, agent_keys=None): + return ars_algorithm.MultiAgentAugmentedRandomSearch( + num_suggestions=4, + step_size=0.5, + std=1., + top_percentage=1, + orthogonal_suggestions=True, + quasirandom_suggestions=False, + top_sort_type='diff', + random_seed=7, + agent_keys=agent_keys) + + @parameterized.parameters( + (None, ['arm', 'opp'], 2), + (['agent_a', 'agent_b', 'agent_c'], ['agent_a', 'agent_b', 'agent_c'], 3), + ) + def test_init(self, agent_keys, expected_agent_keys, expected_num_agents): + algo = self._init_algo(agent_keys) + self.assertListEqual(algo._agent_keys, expected_agent_keys) + self.assertEqual(algo._num_agents, expected_num_agents) + + def _build_evaluation_results(self) -> list[worker_util.EvaluationResult]: + eval_results = [ + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([10., 11., 12., 13.]), + value=10, + metrics={ + 'reward_arm': 10, + 'reward_opp': -5 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([10., 11., 14., 15.]), + value=10, + metrics={ + 'reward_arm': 10, + 'reward_opp': -10 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.empty(0), + value=0, + metrics={ + 'reward_arm': 0, + 'reward_opp': 0 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([1., 2., 3., 4.]), + value=10, + metrics={ + 'reward_arm': 10, + 
'reward_opp': -10 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([10., 11., 12., 13.]), + value=-10, + metrics={ + 'reward_arm': -10, + 'reward_opp': 5 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([10., 11., 14., 15.]), + value=-10, + metrics={ + 'reward_arm': -10, + 'reward_opp': 10 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.array([5., 6., 7., 8.]), + value=-10, + metrics={ + 'reward_arm': -10, + 'reward_opp': 10 + }), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + params_evaluated=np.empty(0), + value=0, + metrics={ + 'reward_arm': 0, + 'reward_opp': 0 + }), + ] + return eval_results + + @parameterized.parameters( + (['arm', 'opp'], [[10., 11.], [12., 13.]], 2), + (['1', '2', '3', '4'], [[10.], [11.], [12.], [13.]], 4), + ) + def test_split_params(self, agent_keys, expected_split_params, num_agents): + algo = self._init_algo(agent_keys=agent_keys) + params = np.array([10., 11., 12., 13.]) + split_params = algo._split_params(params) + self.assertLen(split_params, num_agents) + for p, exp_p in zip(split_params, expected_split_params): + np.testing.assert_array_equal(p, np.asarray(exp_p)) + + @parameterized.parameters( + ([np.asarray([10., 11.]), np.asarray([12., 13.])], + np.asarray([10., 11., 12., 13.])), + ([np.asarray([10.]), np.asarray([11.]), + np.asarray([12.]), np.asarray([13.])], + np.asarray([10., 11., 12., 13.])), + ) + def test_combine_params(self, split_params, expected_combined_params): + algo = self._init_algo() + combined_params = algo._combine_params(split_params) + np.testing.assert_array_equal(combined_params, expected_combined_params) + + @parameterized.parameters( + ('arm', np.asarray([10, 10]), np.asarray([-10, -10]), np.asarray([0, 1])), + ('opp', np.asarray([-10, -5]), np.asarray([10, 5]), np.asarray([1, 0])), + ) + def test_get_top_evaluation_results(self, + agent_key, + expected_pos_evals, + expected_neg_evals, + expected_idx): + algo = self._init_algo() + eval_results = self._build_evaluation_results() + filtered_pos_eval_results = eval_results[:2] + filtered_neg_eval_results = eval_results[4:6] + pos_evals, neg_evals, idx = algo._get_top_evaluation_results( + agent_key=agent_key, + pos_eval_results=filtered_pos_eval_results, + neg_eval_results=filtered_neg_eval_results) + np.testing.assert_array_equal(pos_evals, expected_pos_evals) + np.testing.assert_array_equal(neg_evals, expected_neg_evals) + np.testing.assert_array_equal(idx, expected_idx) + + def test_multi_agent_ars_gradient(self): + algo = self._init_algo() + init_state = {'init_params': np.array([10., 10., 10., 10.])} + algo.initialize(init_state) + suggestions = algo.get_param_suggestions() + self.assertLen(suggestions, 8) + eval_results = self._build_evaluation_results() + algo.process_evaluations(eval_results) + np.testing.assert_array_almost_equal( + algo._opt_params, + np.array([10., 11., 6.83772234, 5.88903904])) + + @parameterized.parameters( + ({'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': None}, + {'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': None}, + False, 2), + ({'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.]), + 'std': np.asarray([5., 6.]), + 'n': 5}}, + {'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.]), + 'std': np.asarray([5., 6.]), + 'n': 5}}, + 
False, 3), + ({'params_to_eval': np.asarray([1, 2,]), + 'obs_norm_state': None}, + {'params_to_eval': np.asarray([1, 2, 1, 2]), + 'obs_norm_state': None}, + True, 2), + ({'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.,]), + 'std': np.asarray([5., 6.,]), + 'n': 5}}, + {'params_to_eval': np.asarray([1, 2, 1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4., 3., 4.]), + 'std': np.asarray([5., 6., 5., 6.]), + 'n': 5}}, + True, 2), + ({'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.]), + 'std': np.asarray([5., 6.]), + 'n': 5}}, + {'params_to_eval': np.asarray([1, 2, 1, 2, 1, 2, 1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4., 3., 4., 3., 4., 3., 4.]), + 'std': np.asarray([5., 6., 5., 6., 5., 6., 5., 6.]), + 'n': 5}}, + True, 4), + ({'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.]), + 'std': np.asarray([5., 6.]), + 'n': 5}}, + {'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([3., 4.]), + 'std': np.asarray([5., 6.]), + 'n': 5}}, + True, 1), + ) + def test_restore_state_from_checkpoint(self, + state, + expected_state, + restore_state_from_single_agent, + num_agents): + algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + num_suggestions=3, + step_size=0.5, + std=1.0, + top_percentage=1, + orthogonal_suggestions=True, + quasirandom_suggestions=True, + obs_norm_data_buffer=normalizer.MeanStdBuffer() + if state['obs_norm_state'] is not None else None, + agent_keys=[str(i) for i in range(num_agents)], + restore_state_from_single_agent=restore_state_from_single_agent, + random_seed=7, + ) + self.assertEqual(algo._num_agents, num_agents) + init_state = {'init_params': np.array([10., 10.])} + if state['obs_norm_state'] is not None: + init_state['obs_norm_buffer_data'] = {'mean': np.asarray([0., 0.]), + 'std': np.asarray([1., 1.]), + 'n': 0} + algo.initialize(init_state) + + with self.subTest('init-mean'): + self.assertAllClose(np.array(algo._opt_params), + init_state['init_params']) + if state['obs_norm_state'] is not None: + with self.subTest('init-obs-mean'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['mean']), + np.asarray(init_state['obs_norm_buffer_data']['mean'])) + with self.subTest('init-obs-n'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['n']), + np.asarray(init_state['obs_norm_buffer_data']['n'])) + with self.subTest('init-obs-std'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['std']), + init_state['obs_norm_buffer_data']['std']) + + algo.restore_state_from_checkpoint(state) + + self.assertAllClose(algo._opt_params, + expected_state['params_to_eval']) + if expected_state['obs_norm_state'] is not None: + std = expected_state['obs_norm_state']['std'] + var = np.square(std) + expected_unnorm_var = var * 4 + with self.subTest('restore-obs-mean'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['mean']), + np.asarray(expected_state['obs_norm_state']['mean'])) + with self.subTest('restore-obs-n'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['n']), + np.asarray(expected_state['obs_norm_state']['n'])) + with self.subTest('restore-obs-std'): + self.assertAllClose( + np.asarray(algo._obs_norm_data_buffer.data['unnorm_var']), + expected_unnorm_var) + + @parameterized.parameters( + ( + {'params_to_eval': np.asarray([1, 2]), 'obs_norm_state': None}, + [ + {'params_to_eval': np.asarray([1]), 'obs_norm_state': None}, + 
{'params_to_eval': np.asarray([2]), 'obs_norm_state': None}, + ], + 2, + ), + ( + { + 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), + 'obs_norm_state': None, + }, + [ + {'params_to_eval': np.asarray([1, 2, 3]), 'obs_norm_state': None}, + {'params_to_eval': np.asarray([4, 5, 6]), 'obs_norm_state': None}, + ], + 2, + ), + ( + { + 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), + 'obs_norm_state': {'mean': np.asarray([6., 7., 8., 9.]), + 'std': np.asarray([10., 11., 12., 13.]), + 'n': 5 + }, + }, + [ + {'params_to_eval': np.asarray([1, 2, 3]), + 'obs_norm_state': {'mean': np.asarray([6., 7.]), + 'std': np.asarray([10., 11.]), + 'n': 5 + }, + }, + {'params_to_eval': np.asarray([4, 5, 6]), + 'obs_norm_state': {'mean': np.asarray([8., 9.]), + 'std': np.asarray([12., 13.]), + 'n': 5 + }, + }, + ], + 2, + ), + ( + { + 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), + 'obs_norm_state': {'mean': np.asarray([[6., 7., 8., 9.], + [10., 11., 12., 13.]]), + 'std': np.asarray([[14., 15., 16., 17.], + [18., 19., 20., 21.]]), + 'n': 5 + }, + }, + [ + {'params_to_eval': np.asarray([1, 2, 3]), + 'obs_norm_state': {'mean': np.asarray([[6., 7.,], [10., 11.,]]), + 'std': np.asarray([[14., 15.,], [18., 19.,]]), + 'n': 5 + }, + }, + {'params_to_eval': np.asarray([4, 5, 6]), + 'obs_norm_state': {'mean': np.asarray([[8., 9.], [12., 13.]]), + 'std': np.asarray([[16., 17.], [20., 21.]]), + 'n': 5 + }, + }, + ], + 2, + ), + ( + { + 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), + 'obs_norm_state': None, + }, + [ + {'params_to_eval': np.asarray([1, 2]), 'obs_norm_state': None}, + {'params_to_eval': np.asarray([3, 4]), 'obs_norm_state': None}, + {'params_to_eval': np.asarray([5, 6]), 'obs_norm_state': None}, + ], + 3, + ), + ( + { + 'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]), + 'obs_norm_state': {'mean': np.asarray( + [[6., 7., 8., 9., 10., 11.], + [12., 13., 14., 15., 16., 17.]]), + 'std': np.asarray( + [[14., 15., 16., 17., 18., 19.], + [20., 21., 22., 23., 24., 25.]]), + 'n': 5 + }, + }, + [ + {'params_to_eval': np.asarray([1, 2]), + 'obs_norm_state': {'mean': np.asarray([[6., 7.,], [12., 13.,]]), + 'std': np.asarray([[14., 15.,], [20., 21.,]]), + 'n': 5 + }, + }, + {'params_to_eval': np.asarray([3, 4]), + 'obs_norm_state': {'mean': np.asarray([[8., 9.], [14., 15.]]), + 'std': np.asarray([[16., 17.], [22., 23.]]), + 'n': 5 + }, + }, + {'params_to_eval': np.asarray([5, 6]), + 'obs_norm_state': {'mean': np.asarray([[10., 11.], [16., 17.]]), + 'std': np.asarray([[18., 19.], [24., 25.]]), + 'n': 5 + }, + }, + ], + 3, + ), + ) + def test_maybe_save_custom_checkpoint(self, + state, + expected_states, + num_agents): + tempdir = self.create_tempdir() + path = 'checkpoint_iteration_0' + full_path = os.path.join(tempdir, path) + algo = ars_algorithm.MultiAgentAugmentedRandomSearch( + num_suggestions=3, + step_size=0.5, + std=1.0, + top_percentage=1, + orthogonal_suggestions=True, + quasirandom_suggestions=True, + obs_norm_data_buffer=normalizer.MeanStdBuffer() + if state['obs_norm_state'] is not None else None, + agent_keys=[str(i) for i in range(num_agents)], + random_seed=7) + algo.maybe_save_custom_checkpoint(state, full_path) + for i in range(num_agents): + agent_checkpoint_path = f'{full_path}_agent_{i}' + agent_state = checkpoint_util.load_checkpoint_state(agent_checkpoint_path) + self.assertAllClose(agent_state['params_to_eval'], + expected_states[i]['params_to_eval']) + if expected_states[i]['obs_norm_state'] is not None: + self.assertAllClose(agent_state['obs_norm_state']['mean'], + 
expected_states[i]['obs_norm_state']['mean'])
+        self.assertAllClose(agent_state['obs_norm_state']['std'],
+                            expected_states[i]['obs_norm_state']['std'])
+        self.assertAllClose(agent_state['obs_norm_state']['n'],
+                            expected_states[i]['obs_norm_state']['n'])
+
+  def test_split_checkpoint(self):
+    tempdir = self.create_tempdir()
+    path = 'checkpoint_iteration_0'
+    full_path = os.path.join(tempdir, path)
+    algo = ars_algorithm.MultiAgentAugmentedRandomSearch(
+        num_suggestions=3,
+        step_size=0.5,
+        std=1.0,
+        top_percentage=1,
+        orthogonal_suggestions=True,
+        quasirandom_suggestions=True,
+        obs_norm_data_buffer=normalizer.MeanStdBuffer(),
+        agent_keys=[str(i) for i in range(3)],
+        random_seed=7)
+    state = {
+        'params_to_eval': np.asarray([1, 2, 3, 4, 5, 6]),
+        'obs_norm_state': {
+            'mean': np.asarray([
+                [6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
+                [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
+            ]),
+            'std': np.asarray([
+                [14.0, 15.0, 16.0, 17.0, 18.0, 19.0],
+                [20.0, 21.0, 22.0, 23.0, 24.0, 25.0],
+            ]),
+            'n': 5,
+        },
+    }
+    expected_states = [
+        {
+            'params_to_eval': np.asarray([1, 2]),
+            'obs_norm_state': {
+                'mean': np.asarray([[6.0, 7.0], [12.0, 13.0]]),
+                'std': np.asarray([[14.0, 15.0], [20.0, 21.0]]),
+                'n': 5,
+            },
+        },
+        {
+            'params_to_eval': np.asarray([3, 4]),
+            'obs_norm_state': {
+                'mean': np.asarray([[8.0, 9.0], [14.0, 15.0]]),
+                'std': np.asarray([[16.0, 17.0], [22.0, 23.0]]),
+                'n': 5,
+            },
+        },
+        {
+            'params_to_eval': np.asarray([5, 6]),
+            'obs_norm_state': {
+                'mean': np.asarray([[10.0, 11.0], [16.0, 17.0]]),
+                'std': np.asarray([[18.0, 19.0], [24.0, 25.0]]),
+                'n': 5,
+            },
+        }]
+    algo.save_checkpoint_oss(full_path, state)
+    algo.split_and_save_checkpoint(checkpoint_path=full_path)
+    for i in range(3):
+      agent_checkpoint_path = f'{full_path}_agent_{i}'
+      agent_state = checkpoint_util.load_checkpoint_state(agent_checkpoint_path)
+      self.assertAllClose(agent_state['params_to_eval'],
+                          expected_states[i]['params_to_eval'])
+      if expected_states[i]['obs_norm_state'] is not None:
+        self.assertAllClose(agent_state['obs_norm_state']['mean'],
+                            expected_states[i]['obs_norm_state']['mean'])
+        self.assertAllClose(agent_state['obs_norm_state']['std'],
+                            expected_states[i]['obs_norm_state']['std'])
+        self.assertAllClose(agent_state['obs_norm_state']['n'],
+                            expected_states[i]['obs_norm_state']['n'])
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/iris/algorithms/optimizers.py b/iris/algorithms/optimizers.py
new file mode 100644
index 0000000..cfbdddc
--- /dev/null
+++ b/iris/algorithms/optimizers.py
@@ -0,0 +1,165 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Library for reconstructing Jacobian J \in R^{m x n}.
+
+A library for reconstructing the approximate version of the Jacobian matrix
+from the set of local noisy linear measurements.
+""" + +# pylint: disable=invalid-name, missing-function-docstring, line-too-long + +import cvxpy as cp +import numpy as np + + +def general_regularized_regression_loss(A, b, x, regularization_parameter, + loss_norm, regularization_norm): + r"""Function implementing general regularized regression loss. + + Implements general regularized regression objective function. + The optimization problem is defined as follows: + min_x ||A*x - b||^{2}_{p} + + regularization_parameter*||x||_{q} + + where: p = loss_norm, q = regularization_norm. + + Args: + A: see the definition of the optimization problem above + b: see the definition of the optimization problem above + x: see the definition of the optimization problem above + regularization_parameter: see description of the optimization problem above + loss_norm: see the definition of the optimization problem above + regularization_norm: see the definition of the optimization problem above + + Returns: + The general regularized loss function. + """ + + def loss_fn(A, b, x): + k = len(A) + b_reshaped = b.reshape((k)) + return cp.pnorm(cp.matmul(A, x) - b_reshaped, p=loss_norm)**loss_norm + + def regularizer(x): + return cp.pnorm(x, p=regularization_norm)**regularization_norm + + return loss_fn(A, b, x) + regularization_parameter * regularizer(x) + + +def vector_decoding_function(A, b, optimization_parameters, loss_function): + r"""Function decoding a vector from the set of liner mesurements. + + Decodes the vector from the set of linear measurements along the directions + defined by the rows of matrix A. Linear measurements are encoded by a + vector b. The decoding is done by minimizing a specific loss_function + parametrized by . + + Args: + A: matrix with rows defining directions along which vector to be + reconstructed is being sensed + b: vector of linear measurements obtained by the above sensing + optimization_parameters: see description of the optimization problem above + loss_function: see description of the optimization problem above + + Returns: + The approximate recovered vector x \in R^{n} + """ + n = len(np.transpose(A)) + x = cp.Variable(n) + regularization_parameter = cp.Parameter(nonneg=True) + regularization_parameter.value = optimization_parameters + problem = cp.Problem( + cp.Minimize(loss_function(A, b, x, regularization_parameter))) + problem.solve(solver=cp.ECOS) + result = x.value + res_list = [] + for i in range(n): + res_list.append(result[i]) + return np.array(res_list) + + +def general_jacobian_decoder(atranspose, yprime, optimization_parameters, + loss_function): + r"""The wrapper around gradient decoder that serves to decode entire Jacobian. + + + Decodes the rows of the Jacobian matrix J \in R^{m x n} and then puts them + together to reconstruct the entire Jacobian. Each row r_{i} is reconstructed + from the collection of the noisy dot products: a_{j}^{T}r \sim y_{j}, + where a_{j} is the jth column of the matrix of samples atranspose and y_{j} is + the jth row of the matrix of measurements yprime \in R^{k x m}. + The reconstruction of each row is handled by function by a certain decoding + function that uses parameters defined in and + which goal is to minimize a specific loss_function parametrized by this set + of parameters. 
+
+  Args:
+    atranspose: see description of the optimization problem above
+    yprime: see description of the optimization problem above
+    optimization_parameters: see description of the optimization problem above
+    loss_function: see description of the optimization problem above
+
+  Returns:
+    The approximate Jacobian J \in R^{m x n}
+  """
+  n = len(atranspose)
+  m = len(np.transpose(yprime))
+  k = len(atranspose[0])
+
+  final_solutions = []
+  for i in range(m):
+    yprime_row = (np.transpose(yprime))[i]
+    yprime_row_reshaped = (yprime_row.reshape((k, 1))).astype(np.double)
+    amatrix = (np.transpose(atranspose)).astype(np.double)
+    res = vector_decoding_function(amatrix, yprime_row_reshaped,
+                                   optimization_parameters, loss_function)
+    list_res = []
+    for j in range(n):
+      list_res.append(res[j])
+    final_solutions.append(np.float32(list_res))
+  return np.array(final_solutions)
+
+
+def l1_regression_loss(A, b, x, regularization_parameter):
+  del regularization_parameter
+  return general_regularized_regression_loss(A, b, x, 0.0, 1, 1)
+
+
+def l1_jacobian_decoder(atranspose, yprime, optimization_parameters):
+  return general_jacobian_decoder(atranspose, yprime, optimization_parameters,
+                                  l1_regression_loss)
+
+
+def lasso_regression_loss(A, b, x, regularization_parameter):
+  return general_regularized_regression_loss(A, b, x, regularization_parameter,
+                                             2, 1)
+
+
+def lasso_regression_jacobian_decoder(atranspose, yprime,
+                                      optimization_parameters):
+  return general_jacobian_decoder(atranspose, yprime, optimization_parameters,
+                                  lasso_regression_loss)
+
+
+def ridge_regression_loss(A, b, x, regularization_parameter):
+  return general_regularized_regression_loss(A, b, x, regularization_parameter,
+                                             2, 2)
+
+
+def ridge_regression_jacobian_decoder(atranspose, yprime,
+                                      optimization_parameters):
+  return general_jacobian_decoder(atranspose, yprime, optimization_parameters,
+                                  ridge_regression_loss)
diff --git a/iris/algorithms/pes_algorithm.py b/iris/algorithms/pes_algorithm.py
new file mode 100644
index 0000000..f8adc76
--- /dev/null
+++ b/iris/algorithms/pes_algorithm.py
@@ -0,0 +1,223 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Algorithm class for the Persistent Evolution Strategies (PES) blackbox algorithm."""
+
+from typing import Any, Dict, Optional, Sequence
+
+from iris import normalizer
+from iris import worker_util
+from iris.algorithms import algorithm
+from iris.algorithms import stateless_perturbation_generators
+import numpy as np
+
+
+class PersistentES(algorithm.BlackboxAlgorithm):
+  """Persistent evolution strategies (PES) algorithm for blackbox optimization."""
+
+  def __init__(self,
+               std: float,
+               step_size: float,
+               top_percentage: float = 1.0,
+               orthogonal_suggestions: bool = False,
+               quasirandom_suggestions: bool = False,
+               top_sort_type: str = "max",
+               obs_norm_data_buffer: Optional[normalizer.MeanStdBuffer] = None,
+               partial_rollout_length: Optional[int] = 5,
+               **kwargs) -> None:
+    """Initializes the persistent evolution strategies algorithm.
+
+    Args:
+      std: Standard deviation for normal perturbations around current
+        optimization parameter vector.
+      step_size: Step size for gradient ascent.
+      top_percentage: Fraction of top performing perturbations to use for
+        gradient estimation.
+      orthogonal_suggestions: Whether to orthogonalize the perturbations.
+      quasirandom_suggestions: Whether quasirandom perturbations should be
+        used; valid only if orthogonal_suggestions = True.
+      top_sort_type: How to sort evaluation results for selecting top
+        directions. Valid options are: "max" and "diff".
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for
+        online mean std observation normalizer.
+      partial_rollout_length: Partial environment rollout length.
+      **kwargs: Other keyword arguments for base class.
+    """
+    super().__init__(**kwargs)
+    self._std = std
+    self._step_size = step_size
+    self._num_top = int(top_percentage * self._num_suggestions)
+    self._num_top = max(1, self._num_top)
+    self._orthogonal_suggestions = orthogonal_suggestions
+    self._quasirandom_suggestions = quasirandom_suggestions
+    self._top_sort_type = top_sort_type
+    self._obs_norm_data_buffer = obs_norm_data_buffer
+    self._partial_rollout_length = partial_rollout_length
+
+  def initialize(self, state: Dict[str, Any]) -> None:
+    """Initializes the algorithm from initial worker state."""
+    self._opt_params = state["init_params"]
+    self._positive_cumulative_perturbations = [0] * self._num_suggestions
+    self._negative_cumulative_perturbations = [0] * self._num_suggestions
+
+    # Initialize the observation normalization buffer with init data from the
+    # worker.
+    if self._obs_norm_data_buffer is not None:
+      self._obs_norm_data_buffer.data = state["obs_norm_buffer_data"]
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]) -> None:
+    """Processes the list of Blackbox function evaluations returned from workers.
+
+    The gradient is computed by taking a weighted sum of the directions and
+    the differences of their values from the current value. The current
+    parameter vector is then updated in the gradient direction with the
+    specified step size.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent. PES performs antithetic
+        gradient estimation. The suggestions are sent for evaluation in pairs.
+        The eval_results list should contain an even number of entries with
+        the first half entries corresponding to evaluation results of positive
+        perturbations and the last half corresponding to negative
+        perturbations.
+    """
+
+    # Retrieve delta direction from the param suggestion sent for evaluation.
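+    # PES-specific detail: perturbations are accumulated across the partial
+    # unrolls of a rollout (Vicol et al., 2021), and the running sums below
+    # are reset whenever a worker reports current_step == 0, i.e. when its
+    # environment rollout starts over from the beginning.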
+ pos_eval_results = eval_results[:self._num_suggestions] + neg_eval_results = eval_results[self._num_suggestions:] + filtered_pos_eval_results = [] + filtered_neg_eval_results = [] + pos_directions = [] + neg_directions = [] + for i in range(len(pos_eval_results)): + if (pos_eval_results[i].params_evaluated.size) and ( + neg_eval_results[i].params_evaluated.size): + filtered_pos_eval_results.append(pos_eval_results[i]) + filtered_neg_eval_results.append(neg_eval_results[i]) + + params = pos_eval_results[i].params_evaluated + pos_directions.append((params - self._opt_params) / self._std) + pos_directions[-1] = self._positive_cumulative_perturbations[ + i] + pos_directions[-1] + if pos_eval_results[i].metrics["current_step"] == 0: + self._positive_cumulative_perturbations[i] = 0 + else: + self._positive_cumulative_perturbations[i] = pos_directions[-1] + + params = neg_eval_results[i].params_evaluated + neg_directions.append((params - self._opt_params) / self._std) + neg_directions[-1] = self._negative_cumulative_perturbations[ + i] + neg_directions[-1] + if neg_eval_results[i].metrics["current_step"] == 0: + self._negative_cumulative_perturbations[i] = 0 + else: + self._negative_cumulative_perturbations[i] = neg_directions[-1] + + pos_directions = np.array(pos_directions) + neg_directions = np.array(neg_directions) + eval_results = filtered_pos_eval_results + filtered_neg_eval_results + + # Get top evaluation results + pos_evals = np.array([r.value for r in filtered_pos_eval_results]) + neg_evals = np.array([r.value for r in filtered_neg_eval_results]) + if self._top_sort_type == "max": + max_evals = np.max(np.vstack([pos_evals, neg_evals]), axis=0) + elif self._top_sort_type == "diff": + max_evals = np.abs(pos_evals - neg_evals) + idx = (-max_evals).argsort()[:self._num_top] + pos_evals = pos_evals[idx] + neg_evals = neg_evals[idx] + all_top_evals = np.hstack([pos_evals, neg_evals]) + + # Get delta directions corresponding to top evals + pos_directions = pos_directions[idx, :] + neg_directions = neg_directions[idx, :] + + # Estimate gradients + gradient = (np.dot(pos_evals, pos_directions) + + np.dot(neg_evals, neg_directions)) / pos_evals.shape[0] + if not np.isclose(np.std(all_top_evals), 0.0): + gradient /= np.std(all_top_evals) + + # Apply gradients + self._opt_params += self._step_size * gradient + + # Update the observation buffer + if self._obs_norm_data_buffer is not None: + for r in eval_results: + self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data) + + def get_param_suggestions(self, + evaluate: bool = False) -> Sequence[Dict[str, Any]]: + """Suggests a list of inputs to evaluate the Blackbox function on. + + Suggestions are sampled from a gaussian distribution around the current + parameter vector. For each suggestion, a dict containing keyword arguments + for the worker is sent. + + Args: + evaluate: Whether to evaluate current optimization variables for reporting + training progress. + + Returns: + A list of suggested inputs for the workers to evaluate. 
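+      Each suggestion is a dict of worker kwargs, e.g. (illustrative values):
+      {"params_to_eval": <perturbed params>, "partial_rollout_length": 5,
+      "obs_norm_state": <normalizer state>, "update_obs_norm_buffer": True}.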
+ """ + if evaluate: + param_suggestions = [self._opt_params] * self._num_evals + else: + dimensions = self._opt_params.shape[0] + param_suggestions = self._np_random_state.normal( + 0, 1, (self._num_suggestions, dimensions)) + if self._orthogonal_suggestions: + if self._quasirandom_suggestions: + param_suggestions = stateless_perturbation_generators.RandomHadamardMatrixGenerator( + self._num_suggestions, dimensions).generate_matrix() + else: + ortho_matrix, _ = np.linalg.qr(param_suggestions.T) + param_suggestions = np.sqrt(dimensions) * ortho_matrix.T + param_suggestions = np.vstack([ + self._opt_params + self._std * param_suggestions, + self._opt_params - self._std * param_suggestions + ]) + + suggestions = [] + for params in param_suggestions: + suggestion = {"params_to_eval": params} + if evaluate: + suggestion["partial_rollout_length"] = None + else: + suggestion["partial_rollout_length"] = self._partial_rollout_length + if self._obs_norm_data_buffer is not None: + suggestion["obs_norm_state"] = self._obs_norm_data_buffer.state + suggestion["update_obs_norm_buffer"] = not evaluate + suggestions.append(suggestion) + return suggestions + + @property + def state(self) -> Dict[str, Any]: + return self._get_state() + + def _get_state(self) -> Dict[str, Any]: + state = {"params_to_eval": self._opt_params} + if self._obs_norm_data_buffer is not None: + state["obs_norm_state"] = self._obs_norm_data_buffer.state + return state + + @state.setter + def state(self, new_state: Dict[str, Any]) -> None: + self._set_state(new_state) + + def _set_state(self, new_state: Dict[str, Any]) -> None: + self._opt_params = new_state["params_to_eval"] + if self._obs_norm_data_buffer is not None: + self._obs_norm_data_buffer.state = new_state["obs_norm_state"] diff --git a/iris/algorithms/pes_algorithm_test.py b/iris/algorithms/pes_algorithm_test.py new file mode 100644 index 0000000..2bce9d9 --- /dev/null +++ b/iris/algorithms/pes_algorithm_test.py @@ -0,0 +1,58 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from iris import worker_util +from iris.algorithms import pes_algorithm +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + + +class AlgorithmTest(parameterized.TestCase): + + @parameterized.parameters( + (True, False), + (False, False), + ) + def test_pes_gradient(self, orthogonal_suggestions, quasirandom_suggestions): + algo = pes_algorithm.PersistentES( + num_suggestions=3, + step_size=0.5, + std=1., + top_percentage=1, + orthogonal_suggestions=orthogonal_suggestions, + quasirandom_suggestions=quasirandom_suggestions, + random_seed=7) + init_state = {'init_params': np.array([10., 10.])} + algo.initialize(init_state) + eval_results = [ + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.array([10., 11.]), 10, metrics={'current_step': 5}), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.empty(0), 0, metrics={'current_step': 5}), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.array([10., 11.]), 10, metrics={'current_step': 5}), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.array([10., 9.]), -10, metrics={'current_step': 5}), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.array([10., 9.]), -10, metrics={'current_step': 5}), + worker_util.EvaluationResult( # pytype: disable=wrong-arg-types # numpy-scalars + np.empty(0), 0, metrics={'current_step': 5}), + ] + algo.process_evaluations(eval_results) + np.testing.assert_array_equal(algo._opt_params, np.array([10, 11])) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/piars_algorithm.py b/iris/algorithms/piars_algorithm.py new file mode 100644 index 0000000..94cfbe6 --- /dev/null +++ b/iris/algorithms/piars_algorithm.py @@ -0,0 +1,487 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Algorithm class for Predictive Information Augmented Random Search.""" + +from typing import Any, Callable, Dict, Optional, Sequence, Union + +from absl import logging +import gym +from gym import spaces +from gym.spaces import utils +from iris import worker_util +from iris.algorithms import ars_algorithm +from iris.policies import keras_pi_policy +import numpy as np +import tensorflow as tf +from tf_agents.agents.categorical_dqn import categorical_dqn_agent +from tf_agents.environments import gym_wrapper +from tf_agents.replay_buffers import reverb_replay_buffer +from tf_agents.specs import tensor_spec +from tf_agents.trajectories import policy_step +from tf_agents.trajectories import time_step as ts +from tf_agents.trajectories import trajectory +from tf_agents.utils import composite +from tf_agents.utils import eager_utils + + +class PIARS(ars_algorithm.AugmentedRandomSearch): + """Augmented random search on predictive representations.""" + + def __init__( + self, + env: Union[gym.Env, Callable[[], gym.Env]], + policy: Union[ + keras_pi_policy.KerasPIPolicy, + Callable[..., keras_pi_policy.KerasPIPolicy], + ], + env_args: Optional[Dict[str, Any]] = None, + policy_args: Optional[Dict[str, Any]] = None, + learn_representation: bool = True, + **kwargs + ) -> None: + """Initializes the augmented random search algorithm. + + Args: + env: Gym RL environment object to run rollout with. + policy: Policy object to map observations to actions. + env_args: Arguments for env constructor. + policy_args: Arguments for policy constructor. + learn_representation: Whether to learn representation. + **kwargs: Other keyword arguments for base class. + """ + super().__init__(**kwargs) + self._env = env + self._env_args = env_args + self._policy = policy + self._policy_args = policy_args + self.learn_representation = learn_representation + self._representation_params = np.empty(0) + self._representation_learner = RepresentationLearner( + self._env, + self._env_args, + self._policy, + self._policy_args, + "dummy_reverb_server_addr", + ) + + def initialize(self, state: Dict[str, Any]) -> None: + """Initializes the algorithm from initial worker state.""" + super().initialize(state) + self._representation_params = state["init_representation_params"] + + # Initialize representation learner + if self.learn_representation: + self._representation_learner = RepresentationLearner( + self._env, + self._env_args, + self._policy, + self._policy_args, + state["reverb_server_addr"], + ) + self._representation_learner.policy.update_weights(state["init_params"]) + self._representation_learner.policy.update_representation_weights( + state["init_representation_params"] + ) + self._representation_learner.policy.reset() + + def process_evaluations( + self, eval_results: Sequence[worker_util.EvaluationResult] + ) -> None: + """Processes the list of Blackbox function evaluations return from workers. + + Gradient is computed by taking a weighted sum of directions and + difference of their value from the current value. The current parameter + vector is then updated in the gradient direction with specified step size. + + Args: + eval_results: List containing Blackbox function evaluations based on the + order in which the suggestions were sent. ARS performs antithetic + gradient estimation. The suggestions are sent for evaluation in pairs. 
+ The eval_results list should contain an even number of entries with the + first half entries corresponding to evaluation result of positive + perturbations and the last half corresponding to negative perturbations. + """ + super().process_evaluations(eval_results) + + # Train representations + obs_norm_state = None + if self._obs_norm_data_buffer is not None: + obs_norm_state = self.state["obs_norm_state"] + if self.learn_representation: + self._representation_learner.train(obs_norm_state) + + def get_param_suggestions( + self, evaluate: bool = False + ) -> Sequence[Dict[str, Any]]: + """Suggests a list of inputs to evaluate the Blackbox function on. + + Suggestions are sampled from a gaussian distribution around the current + parameter vector. For each suggestion, a dict containing keyword arguments + for the worker is sent. + + Args: + evaluate: Whether to evaluate current optimization variables for reporting + training progress. + + Returns: + A list of suggested inputs for the workers to evaluate. + """ + suggestions = super().get_param_suggestions(evaluate) + + # Pull representation weights from representation learner + if self.learn_representation: + self._representation_params = ( + self._representation_learner.policy.get_representation_weights() + ) + for suggestion in suggestions: + suggestion["representation_params"] = self._representation_params + + return suggestions + + def _get_state(self) -> Dict[str, Any]: + state = super()._get_state() + if self.learn_representation: + state["representation_params"] = self._representation_params + return state + + def _set_state(self, new_state: Dict[str, Any]) -> None: + super()._set_state(new_state) + if self.learn_representation: + self._representation_params = new_state["representation_params"] + # Set policy weights in representation learner + self._representation_learner.policy.update_weights( + new_state["params_to_eval"] + ) + self._representation_learner.policy.update_representation_weights( + new_state["representation_params"] + ) + self._representation_learner.policy.reset() + + +class RepresentationLearner(object): + """Representation learner.""" + + def __init__( + self, + env, + env_args, + policy, + policy_args, + reverb_server_address, + rollout_length=5, + batch_size=512, + weight_decay=1e-5, + gamma=0.99, + num_supports=51, + min_support=-10.0, + max_support=10.0, + use_pi_loss=True, + use_imitation_loss=False, + use_value_loss=False, + learning_rate=1e-4, + grad_clip=0.5, + grad_step=2, + ): + env_args = {} if env_args is None else env_args + policy_args = {} if policy_args is None else policy_args + self._env = env(**env_args) if not isinstance(env, gym.Env) else env + + if not isinstance(policy, keras_pi_policy.KerasPIPolicy): + self.policy = policy( + ob_space=self._env.observation_space, + ac_space=self._env.action_space, + **policy_args + ) + else: + self.policy = policy + + obs_spec = gym_wrapper.spec_from_gym_space(self._env.observation_space) + action_spec = gym_wrapper.spec_from_gym_space(self._env.action_space) + time_step_spec = ts.time_step_spec(observation_spec=obs_spec) + policy_step_spec = policy_step.PolicyStep(action=action_spec) + collect_data_spec = trajectory.from_transition( + time_step_spec, policy_step_spec, time_step_spec + ) + collect_data_spec = tensor_spec.from_spec(collect_data_spec) + self.reverb_rb = reverb_replay_buffer.ReverbReplayBuffer( + collect_data_spec, + sequence_length=(rollout_length + 1), + table_name="uniform_table", + server_address=reverb_server_address, + ) + self._dataset 
= self.reverb_rb.as_dataset(
+        sample_batch_size=batch_size, num_steps=(rollout_length + 1)
+    )
+    self._data_iter = iter(self._dataset)
+    # For distributional value function.
+    self._num_supports = num_supports
+    self._min_support = min_support
+    self._max_support = max_support
+    self.supports = tf.linspace(min_support, max_support, num_supports)
+
+    self._optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
+    self._rollout_length = rollout_length
+    self._gamma = gamma
+    self._weight_decay = weight_decay
+    self.use_pi_loss = use_pi_loss
+    self.use_imitation_loss = use_imitation_loss
+    self.use_value_loss = use_value_loss
+    self.grad_clip = grad_clip
+    self.grad_step = grad_step
+    # TODO: Checkpoint global step.
+    self.global_step = 0
+    self.reverb_checkpoint_period = 20
+
+  def train(self, obs_norm_state=None):
+    """Train representation from replay data.
+
+    Args:
+      obs_norm_state: Observation normalizer state (mean and std).
+    """
+
+    for _ in range(self.grad_step):
+      traj, _ = next(self._data_iter)
+      discount = traj.discount
+      reward = tf.nest.map_structure(
+          lambda t: composite.slice_to(t, axis=1, end=-1), traj.reward
+      )
+      action = tf.nest.map_structure(
+          lambda t: composite.slice_to(t, axis=1, end=-1), traj.action
+      )
+      discount = tf.nest.map_structure(
+          lambda t: composite.slice_to(t, axis=1, end=-1), discount
+      )
+
+      obs = traj.observation
+      # Observation normalization.
+      if obs_norm_state is not None:
+        obs_mean = obs_norm_state["mean"]
+        obs_mean = utils.unflatten(self._env.observation_space, obs_mean)
+        obs_std = obs_norm_state["std"]
+        obs_std = utils.unflatten(self._env.observation_space, obs_std)
+        obs = tf.nest.map_structure(lambda x, y: x - y, obs, obs_mean)
+        obs = tf.nest.map_structure(lambda x, y: x / (y + 1e-8), obs, obs_std)
+
+      # Separate vision input and other observations.
+      obs_flat = []
+      for image_label in self.policy._image_input_labels:  # pylint: disable=protected-access
+        vision_input = obs[image_label]
+        obs_flat.append(vision_input)
+
+      other_ob = obs.copy()
+      for image_label in self.policy._image_input_labels:  # pylint: disable=protected-access
+        del other_ob[image_label]
+
+      # Flatten other observations.
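+      # flatten_nested (defined at the bottom of this file) concatenates
+      # Box/Tuple/Dict/MultiBinary/MultiDiscrete sub-spaces into a single
+      # flat feature vector.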
+ other_input = flatten_nested(self.policy._other_ob_space, other_ob) # pylint: disable=protected-access + obs_flat.append(other_input) + + loss, _ = self.train_single_step(obs_flat, reward, action, discount) + self.global_step += 1 + if self.global_step % self.reverb_checkpoint_period == 0: + logging.info("Start checkpointing reverb data.") + self.reverb_rb.py_client.checkpoint() + print("train/loss: {}".format(np.mean(loss.numpy()))) + + @tf.function + def train_single_step(self, obs, reward, action, discount): + """One gradient step.""" + trainable_variables = self.policy.h_model.trainable_weights + trainable_variables += self.policy.f_model.trainable_weights + trainable_variables += self.policy.g_model.trainable_weights + if self.use_pi_loss: + trainable_variables += self.policy.px_model.trainable_weights + trainable_variables += self.policy.py_model.trainable_weights + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(trainable_variables) + loss, metrics = self.loss(obs, reward, action, discount) + loss_reg = ( + tf.add_n([tf.nn.l2_loss(v) for v in trainable_variables]) + * self._weight_decay + ) + loss += loss_reg + tf.debugging.check_numerics(loss, "loss is inf or nan.") + grads = tape.gradient(loss, trainable_variables) + grads_and_vars = list(zip(grads, trainable_variables)) + if self.grad_clip is not None: + grads_and_vars = eager_utils.clip_gradient_norms( + grads_and_vars, self.grad_clip + ) + self._optimizer.apply_gradients(grads_and_vars) + return loss, metrics + + @tf.function + def rollout(self, obs, actions): + """Latent rollout.""" + s, _ = self.policy.h_model(obs) + outputs = [] + for i in range(self._rollout_length): + p, v = self.policy.f_model(s) + u_next, s_next = self.policy.g_model([s, actions[:, i, ...]]) + outputs.append((p, v, u_next, s)) + s = s_next + p, v = self.policy.f_model(s) + outputs.append((p, v, None, s)) + return outputs + + @tf.function + def loss(self, obs, rewards, actions, discount): + """Representation and dynamics loss.""" + # Ex. obs: [(B, T, 24, 32, 1), (B, T, 24, 32, 1), (B, T, 68)] + # obs0: [(B, 24, 32, 1), (B, 24, 32, 1), (B, 68)] + obs0 = tf.nest.map_structure( + lambda t: composite.slice_to(t, axis=1, end=1), obs + ) + obs0 = [tf.squeeze(x, axis=1) for x in obs0] + latent_traj = self.rollout(obs0, actions) + + loss_pi = 0.0 # Predictive Information Loss + loss_p = 0.0 # Imitation Loss + loss_v = 0.0 # Value loss + loss_r = 0.0 # Reward loss + + def infonce(hidden_x, hidden_y, temperature=0.1): + hidden_x = tf.math.l2_normalize(hidden_x, -1) + hidden_y = tf.math.l2_normalize(hidden_y, -1) + batch_size = tf.shape(hidden_x)[0] + labels = tf.one_hot(tf.range(batch_size), batch_size) + logits = tf.matmul(hidden_x, hidden_y, transpose_b=True) / temperature + hyz = tf.nn.softmax_cross_entropy_with_logits(labels, logits) + iyz = tf.math.log(tf.cast(batch_size, tf.float32)) - hyz + return iyz, logits, labels + + if self.use_pi_loss: + k = self._rollout_length + obs_k = [x[:, k, ...] 
for x in obs] + # Latent state (from visual + other observations) for the first time step + hx = latent_traj[0][-1] + # Latent state (from visual observations) for the last time step + _, hy_vision = self.policy.h_model(obs_k) + # A trick from https://arxiv.org/abs/2011.10566 + hy_vision = tf.stop_gradient(hy_vision) + zx = self.policy.px_model(hx) + zy = self.policy.py_model(hy_vision) + iyz, _, _ = infonce(zx, zy, temperature=0.1) + loss_pi = -iyz + + # Compute target values + if self.use_value_loss: + # Value distribution for the last time step + last_value_distribution = latent_traj[-1][1] + target_value_supports = [self.supports] # not used in loss + for i in range(self._rollout_length - 2, -1, -1): + r_next = rewards[:, i : i + 1, ...] + d_next = discount[:, i : i + 1, ...] + target_support = ( + target_value_supports[-1] * d_next * self._gamma + r_next + ) + target_value_supports.append(target_support) + target_value_supports.reverse() + + vd = tf.nn.softmax(latent_traj[-1][1]) + pred_value_sum = tf.reduce_sum(vd * self.supports[None, ...], axis=-1) + + for i in range(self._rollout_length - 1): + p = latent_traj[i][0] + z = latent_traj[i][1] + u_next = latent_traj[i][2] + + if self.use_imitation_loss: + loss_p += tf.reduce_sum( + tf.math.square(p - tf.stop_gradient(actions[:, i, ...])), -1 + ) + if self.use_value_loss: + loss_v += self.distributional_value_loss( + value_logits=z, + value_supports=self.supports, + target_value_logits=last_value_distribution, + target_value_supports=target_value_supports[i], + ) + vd = tf.nn.softmax(z) + pred_value_sum += tf.reduce_sum(vd * self.supports[None, ...], axis=-1) + # reward loss + loss_r += tf.reduce_sum( + tf.math.square(u_next - tf.stop_gradient(rewards[:, i : i + 1])), -1 + ) + loss = loss_r + loss_v + loss_p + loss_pi + loss = tf.reduce_mean(loss) + metrics = { + "loss_r": tf.reduce_mean(loss_r), + "loss_v": tf.reduce_mean(loss_v), + "loss_p": tf.reduce_mean(loss_p), + "loss_pi": tf.reduce_mean(loss_pi), + } + if self.use_value_loss: + metrics["value"] = tf.reduce_mean(pred_value_sum) / self._rollout_length + if self.use_pi_loss: + metrics["iyz"] = tf.reduce_mean(iyz) + return loss, metrics + + def distributional_value_loss( + self, + value_logits, + value_supports, + target_value_logits, + target_value_supports, + ): + """Computes the distributional value loss.""" + target_value_probs = tf.nn.softmax(target_value_logits) + projected_target_probs = tf.stop_gradient( + categorical_dqn_agent.project_distribution( + target_value_supports, target_value_probs, value_supports + ) + ) + + value_loss = tf.nn.softmax_cross_entropy_with_logits( + logits=value_logits, labels=projected_target_probs + ) + + return value_loss + + +def flatten_nested(space, x): + """Flatten nested.""" + if isinstance(space, spaces.Box): + x = np.asarray(x, dtype=np.float32) + inner_dims = list(space.shape) + outer_dims = list(x.shape)[: -len(inner_dims)] + x = np.reshape(x, outer_dims + [np.prod(inner_dims)]) + return x + elif isinstance(space, spaces.Tuple): + return np.concatenate( + [flatten_nested(s, x_part) for x_part, s in zip(x, space.spaces)], + axis=-1, + ) + elif isinstance(space, spaces.Dict): + return np.concatenate( + [flatten_nested(space.spaces[key], item) for key, item in x.items()], + axis=-1, + ) + elif isinstance(space, spaces.MultiBinary): + x = np.asarray(x) + space = np.asarray(space) + inner_dims = list(space.shape) + outer_dims = list(x.shape)[: -len(inner_dims)] + x = np.reshape(x, outer_dims + [np.prod(inner_dims)]) + return x + elif 
isinstance(space, spaces.MultiDiscrete): + x = np.asarray(x) + space = np.asarray(space) + inner_dims = list(space.shape) + outer_dims = list(x.shape)[: -len(inner_dims)] + x = np.reshape(x, outer_dims + [np.prod(inner_dims)]) + return x + else: + raise NotImplementedError diff --git a/iris/algorithms/piars_algorithm_test.py b/iris/algorithms/piars_algorithm_test.py new file mode 100644 index 0000000..58b9865 --- /dev/null +++ b/iris/algorithms/piars_algorithm_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +from iris import worker_util +from iris.algorithms import piars_algorithm +from iris.policies import keras_pi_policy +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + + +class AlgorithmTest(parameterized.TestCase): + + @parameterized.parameters( + (True, False), + (False, False), + ) + def test_ars_gradient(self, orthogonal_suggestions, quasirandom_suggestions): + env = gym.make(id='Pendulum-v0') + policy = keras_pi_policy.KerasPIPolicy( + ob_space=env.observation_space, + ac_space=env.action_space, + state_dim=2, + conv_filter_sizes=(2,), + conv_kernel_sizes=(2,), + image_feature_length=2, + fc_layer_sizes=(2,), + h_fc_layer_sizes=(2,), + f_fc_layer_sizes=(2,), + g_fc_layer_sizes=(2,), + ) + algo = piars_algorithm.PIARS( + num_suggestions=3, + step_size=0.5, + std=1.0, + top_percentage=1, + orthogonal_suggestions=orthogonal_suggestions, + quasirandom_suggestions=quasirandom_suggestions, + env=env, + policy=policy, + learn_representation=False, + random_seed=7, + ) + init_state = { + 'init_params': np.array([10.0, 10.0]), + 'init_representation_params': np.array([10.0, 10.0]), + } + algo.initialize(init_state) + eval_results = [ + worker_util.EvaluationResult(np.array([10.0, 11.0]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10.0, 11.0]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10.0, 9.0]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10.0, 9.0]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + ] + algo.process_evaluations(eval_results) + np.testing.assert_array_equal(algo._opt_params, np.array([10, 11])) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/pyglove_algorithm.py b/iris/algorithms/pyglove_algorithm.py new file mode 100644 index 0000000..158c3ac --- /dev/null +++ b/iris/algorithms/pyglove_algorithm.py @@ -0,0 +1,115 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Algorithm class for any PyGlove controller-only algorithm.""" +import functools +from multiprocessing import dummy as mp_threads +from typing import Any, Dict, Sequence +from iris import worker_util +from iris.algorithms import algorithm +from iris.algorithms import controllers +import numpy as np +import pyglove as pg + + +class PyGloveAlgorithm(algorithm.BlackboxAlgorithm): + """Uses a PyGlove algorithm end-to-end for entire Blackbox Algorithm.""" + + def __init__(self, + controller_str: str = "regularized_evolution", + multithreading: bool = False, + **kwargs) -> None: + """Initializes the PyGlove algorithm. + + Args: + controller_str: Which controller algorithm to use on PyGlove side. + multithreading: Whether to multithread PyGlove DNA serialization. Pool + created after __init__ to avoid Launchpad pickling issues. + **kwargs: Arguments to parent BlackboxAlgorithm class. + """ + super().__init__(**kwargs) + self._controller_fn = functools.partial( + controllers.CONTROLLER_DICT[controller_str], + batch_size=self._num_suggestions) + + self._multithreading = multithreading + + def initialize(self, state: Dict[str, Any]) -> None: + if self._multithreading: + self._pool = mp_threads.Pool(self._num_suggestions) + + self._dna_spec = pg.from_json_str(state["serialized_dna_spec"]) + self._controller = self._controller_fn(dna_spec=self._dna_spec) + self._evaluated_serialized_dnas = [] + self._evaluated_rewards = [] + + def process_evaluations( + self, eval_results: Sequence[worker_util.EvaluationResult]) -> None: + eval_metadatas = [] + eval_rewards = [] + for eval_result in eval_results: + if eval_result.metadata: + eval_metadatas.append(eval_result.metadata) + eval_rewards.append(eval_result.value) + + def proper_unserialize(metadata: str) -> pg.DNA: + dna = pg.from_json_str(metadata) + # Put back the DNASpec into DNA, since serialization removed it. + dna.use_spec(self._dna_spec) + return dna + + if self._multithreading: + dna_list = self._pool.map(proper_unserialize, eval_metadatas) # pytype:disable=attribute-error + else: + dna_list = map(proper_unserialize, eval_metadatas) + dna_list = list(dna_list) + + for dna in dna_list: + + dna.use_spec(self._dna_spec) + self._controller.collect_rewards_and_train(eval_rewards, dna_list) + + def get_param_suggestions(self, + evaluate: bool = False) -> Sequence[Dict[str, Any]]: + vanilla_suggestions = [] + + dna_list = [ + self._controller.propose_dna() for _ in range(self._num_suggestions) + ] + # Note that for faster serialization, DNASpec is removed from DNA. 
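+    # Workers echo this JSON string back in `EvaluationResult.metadata`, and
+    # `process_evaluations` re-attaches the spec via `dna.use_spec` before
+    # handing the DNA list to the controller for training.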
+ if self._multithreading: + metadata_list = self._pool.map(pg.to_json_str, dna_list) # pytype:disable=attribute-error + else: + metadata_list = map(pg.to_json_str, dna_list) + + metadata_list = list(metadata_list) + + for metadata in metadata_list: + suggestion = {"params_to_eval": np.empty((), dtype=np.float64)} + suggestion["metadata"] = metadata + vanilla_suggestions.append(suggestion) + + return vanilla_suggestions + + def _get_state(self) -> Dict[str, Any]: + vanilla_state = {} + vanilla_state["serialized_dna_spec"] = pg.to_json_str(self._dna_spec) # pytype:disable=attribute-error + vanilla_state["controller_alg_state"] = self._controller.get_state() # pytype:disable=attribute-error + return vanilla_state + + def _set_state(self, new_state: Dict[str, Any]) -> None: + self._interval_counter = new_state["interval_counter"] + self._dna_spec = pg.from_json_str(new_state["serialized_dna_spec"]) + self._controller = self._controller_fn(dna_spec=self._dna_spec) + self._controller.set_state(new_state["controller_alg_state"]) diff --git a/iris/algorithms/pyglove_algorithm_test.py b/iris/algorithms/pyglove_algorithm_test.py new file mode 100644 index 0000000..c70cff9 --- /dev/null +++ b/iris/algorithms/pyglove_algorithm_test.py @@ -0,0 +1,81 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from iris import worker_util +from iris.algorithms import pyglove_algorithm +import numpy as np +import pyglove as pg +from absl.testing import absltest +from absl.testing import parameterized + + +def make_init_state(): + dna_spec = pg.template(pg.one_of(['a', 'b'])).dna_spec() + return {'serialized_dna_spec': pg.to_json_str(dna_spec)} + + +def make_evaluation_results(suggestion_list): + eval_results = [] + for suggestion in suggestion_list[:-1]: + evaluation_result = worker_util.EvaluationResult( + params_evaluated=suggestion['params_to_eval'], # FYI Empty array. 
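+        # For PyGlove, the genome travels in `metadata` as a JSON string; the
+        # params array itself stays empty.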
+ value=np.random.uniform(), + metadata=suggestion['metadata']) + eval_results.append(evaluation_result) + eval_results.append(worker_util.EvaluationResult(np.empty(0), 0)) + return eval_results + + +class PygloveAlgorithmTest(parameterized.TestCase): + + def setUp(self): + self.num_suggestions = 100 + self.random_seed = 7 + super().setUp() + + @parameterized.named_parameters( + ('hill_climb', 'hill_climb'), ('neat', 'neat'), + ('policy_gradient', 'policy_gradient'), + ('random_search', 'random_search'), + ('regularized_evolution', 'regularized_evolution')) + def test_pyglove_algo_step(self, controller_str): + algo = pyglove_algorithm.PyGloveAlgorithm( + controller_str=controller_str, + num_suggestions=self.num_suggestions, + random_seed=self.random_seed) + + init_state = make_init_state() + algo.initialize(init_state) + + suggestion_list = algo.get_param_suggestions(evaluate=False) + eval_results = make_evaluation_results(suggestion_list) + algo.process_evaluations(eval_results) + + @parameterized.named_parameters(('False', False), ('True', True)) + def test_multithreading(self, multithreading): + algo = pyglove_algorithm.PyGloveAlgorithm( + multithreading=multithreading, + num_suggestions=self.num_suggestions, + random_seed=self.random_seed) + + init_state = make_init_state() + algo.initialize(init_state) + + suggestion_list = algo.get_param_suggestions(evaluate=False) + eval_results = make_evaluation_results(suggestion_list) + algo.process_evaluations(eval_results) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/pyribs_algorithm.py b/iris/algorithms/pyribs_algorithm.py new file mode 100644 index 0000000..36f0576 --- /dev/null +++ b/iris/algorithms/pyribs_algorithm.py @@ -0,0 +1,248 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Algorithm class for PyRibs Quality Diversity Search. + +See https://arxiv.org/abs/2303.00191 for a description of this library and +summary of Quality Diversity search algorithms. + +For simplicity, this implementation only exposes a subset of the functionality +in PyRibs. Currently just Covariance Matrix Adaptation MAP-Elites (CMA-ME) with +a grid-based archive. +""" + +import dataclasses +from typing import Any, Dict, Sequence + +from iris import normalizer +from iris import worker_util +from iris.algorithms import algorithm +import numpy as np +from ribs import archives +from ribs import emitters +from ribs import schedulers +from typing_extensions import override + + +_ARCHIVE_DATA = "archive_data" +# Pyribs internal column names for the archive. +_INDEX = "index" +_SOLUTION = "solution" +# Extra column names for storing normalizer data with solutions. 
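+# Each archived solution carries its own snapshot of the observation
+# normalizer under these columns, mirroring the fields of
+# `normalizer.MeanStdBuffer.state`.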
+_OBS_NORM_PREFIX = "obs_norm_"
+_OBS_NORM_MEAN = _OBS_NORM_PREFIX + normalizer.MEAN
+_OBS_NORM_STD = _OBS_NORM_PREFIX + normalizer.STD
+_OBS_NORM_N = _OBS_NORM_PREFIX + normalizer.N
+
+
+@dataclasses.dataclass(frozen=True)
+class MeasureSpec:
+  """Specifications for behavior measures."""
+
+  # Name of the behavior measure; must be a metric exported by the blackbox
+  # workers.
+  name: str
+  # Range of values the measure can take.
+  range: tuple[float, float]
+  # Number of buckets to divide the above range into.
+  num_buckets: int
+
+
+class PyRibsAlgorithm(algorithm.BlackboxAlgorithm):
+  """Quality Diversity search for the blackbox optimization framework.
+
+  Defines a quality diversity search that can be executed with the blackbox
+  optimization framework (BBV2). The search uses the CMA-ME algorithm and a
+  grid archive for tracking solutions.
+  """
+
+  def __init__(
+      self,
+      measure_specs: Sequence[MeasureSpec],
+      obs_norm_data_buffer: normalizer.MeanStdBuffer,
+      initial_step_size: float,
+      num_suggestions_per_emitter: int,
+      num_emitters: int,
+      num_evals: int,
+      qd_score_offset: float = 0,
+      solution_ranker: str = "2imp",
+  ) -> None:
+    """Initializes a PyRibsAlgorithm.
+
+    Args:
+      measure_specs: List of behavior measures to optimize over. These must be
+        defined in metrics exported by the workers.
+      obs_norm_data_buffer: Buffer to sync statistics from all workers for the
+        online mean-std observation normalizer.
+      initial_step_size: Starting step size of the search.
+      num_suggestions_per_emitter: Number of suggestions per emitter. Total
+        suggestions = num_suggestions_per_emitter * num_emitters.
+      num_emitters: Number of suggestion emitters. More emitters implies more
+        varied exploration.
+      num_evals: Number of evaluations to perform on the top solution for
+        reporting the top score.
+      qd_score_offset: Value to add to rewards such that good solutions are
+        non-negative, see ribs/archives/_archive_base.py for details. Default
+        of 0 means no adjustment is applied and negative rewards are not
+        admitted to the archive.
+      solution_ranker: String abbreviation of the ranker for emitting
+        solutions. See ribs/emitters/rankers.py for options. Default ("2imp")
+        is the two-stage improvement ranker.
+    """
+    self._initial_step_size = initial_step_size
+    self._num_suggestions_per_emitter = num_suggestions_per_emitter
+    self._num_emitters = num_emitters
+    self._obs_norm_data_buffer = obs_norm_data_buffer
+    self._measure_specs = measure_specs
+    self._qd_score_offset = qd_score_offset
+    self._solution_ranker = solution_ranker
+
+    self._measure_names = [measure.name for measure in measure_specs]
+    self._archive_dims = [measure.num_buckets for measure in measure_specs]
+    self._archive_ranges = [measure.range for measure in self._measure_specs]
+    self._opt_params = np.empty(0)
+    self._scheduler = None
+    self._init_scheduler()
+    super().__init__(
+        num_suggestions=num_suggestions_per_emitter * num_emitters,
+        random_seed=42,  # Unused.
+        num_evals=num_evals,
+    )
+
+  def _init_scheduler(
+      self, saved_archive: dict[str, np.ndarray] | None = None
+  ) -> None:
+    """Initializes the archive and scheduler for PyRibs.
+
+    Args:
+      saved_archive: Optional saved archive state to restore from.
+    """
+
+    # TODO: Eventually have a `state_spec` in the buffer class.
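+    # Shapes of the extra archive fields are derived from the normalizer
+    # buffer state, so each elite can later be evaluated with exactly the
+    # observation statistics it was trained under.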
+ buffer_state = self._obs_norm_data_buffer.state + self._archive = archives.GridArchive( + solution_dim=self._opt_params.size, + dims=self._archive_dims, + ranges=self._archive_ranges, + qd_score_offset=self._qd_score_offset, + extra_fields={ + _OBS_NORM_MEAN: (buffer_state[normalizer.MEAN].size, np.float32), + _OBS_NORM_STD: (buffer_state[normalizer.STD].size, np.float32), + _OBS_NORM_N: ((), np.int32), + }, + ) + + if saved_archive is not None: + del saved_archive[_INDEX] # Index is not needed to restore state. + self._archive.add(**saved_archive) + + archive_emitters = [ + emitters.EvolutionStrategyEmitter( + archive=self._archive, + x0=self._opt_params.flatten(), + sigma0=self._initial_step_size, + ranker=self._solution_ranker, + batch_size=self._num_suggestions_per_emitter, + ) + for _ in range(self._num_emitters) + ] + self._scheduler = schedulers.Scheduler(self._archive, archive_emitters) + + @override + def initialize(self, state: dict[str, Any]): + self._opt_params = state[algorithm.PARAMS_TO_EVAL] + if algorithm.OBS_NORM_BUFFER_STATE in state: + self._obs_norm_data_buffer.data = state[algorithm.OBS_NORM_BUFFER_STATE] + self._init_scheduler() + + @override + def get_param_suggestions( + self, evaluate: bool = False + ) -> Sequence[Dict[str, Any]]: + if evaluate and self._archive.best_elite is None: + return [] + + if evaluate: + elite = self._archive.best_elite + param_suggestions = [elite[_SOLUTION]] * self._num_evals + buffer = { + normalizer.N: elite[_OBS_NORM_N], + normalizer.MEAN: elite[_OBS_NORM_MEAN], + normalizer.STD: elite[_OBS_NORM_STD], + } + else: + param_suggestions = self._scheduler.ask() + buffer = self._obs_norm_data_buffer.state + + return [ + { + algorithm.PARAMS_TO_EVAL: params, + algorithm.OBS_NORM_BUFFER_STATE: buffer, + algorithm.UPDATE_OBS_NORM_BUFFER: not evaluate, + } + for params in param_suggestions + ] + + def process_evaluations( + self, eval_results: Sequence[worker_util.EvaluationResult] + ) -> None: + objective = [] + measures = [] + obs_norm_n = [] + obs_norm_std = [] + obs_norm_mean = [] + for result in eval_results: + self._obs_norm_data_buffer.merge(result.obs_norm_buffer_data) + objective.append(result.value) + measures.append([result.metrics[name] for name in self._measure_names]) + obs_norm_n.append(result.obs_norm_buffer_data[normalizer.N]) + obs_norm_std.append(result.obs_norm_buffer_data[normalizer.STD]) + obs_norm_mean.append(result.obs_norm_buffer_data[normalizer.MEAN]) + + # Store the state of the obs_norm_buffer for each solution so that it can be + # reproduced later when evaluating the policy, similar to other algorithms + # that use the Blackbox framework. + extra_fields = { + _OBS_NORM_MEAN: obs_norm_mean, + _OBS_NORM_STD: obs_norm_std, + _OBS_NORM_N: obs_norm_n, + } + + self._scheduler.tell( + objective=objective, + measures=measures, + **extra_fields, + ) + + @property + @override + def state(self) -> Dict[str, Any]: + return { + algorithm.PARAMS_TO_EVAL: self._opt_params, + algorithm.OBS_NORM_BUFFER_STATE: self._obs_norm_data_buffer.state, + # Phoenix cannot serialize the archive directly, so use the dict state. 
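+        # `archive.data()` yields a dict of numpy arrays keyed by column name;
+        # `_init_scheduler` rebuilds the archive from it with `add(**data)`
+        # after dropping the index column.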
+ _ARCHIVE_DATA: self._archive.data(), + } + + @override + def restore_state_from_checkpoint(self, new_state: Dict[str, Any]) -> None: + self.state = new_state + + @state.setter + @override + def state(self, new_state: Dict[str, Any]) -> None: + self._opt_params = new_state[algorithm.PARAMS_TO_EVAL] + self._obs_norm_data_buffer.state = new_state[ + algorithm.OBS_NORM_BUFFER_STATE + ] + self._init_scheduler(new_state.get(_ARCHIVE_DATA, None)) diff --git a/iris/algorithms/pyribs_algorithm_test.py b/iris/algorithms/pyribs_algorithm_test.py new file mode 100644 index 0000000..7483561 --- /dev/null +++ b/iris/algorithms/pyribs_algorithm_test.py @@ -0,0 +1,264 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from iris import normalizer +from iris import worker_util +from iris.algorithms import algorithm +from iris.algorithms import pyribs_algorithm +import numpy as np +from ribs import archives + +from absl.testing import absltest + +# Define two arbitrary specs with different ranges so we can test for them. +_X_SPEC = pyribs_algorithm.MeasureSpec('x', (0, 10), 10) +_Y_SPEC = pyribs_algorithm.MeasureSpec('y', (1, 100), 20) + + +class PyribsAlgorithmTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Basic parameters chosen to be simple enough to not distract from the + # algorithm logic but with enough complexity to test functionality e.g. + # using multiple measure specs. 
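+    # The buffer is 8-dimensional while the parameters are 13-dimensional, so
+    # the two shapes cannot be silently swapped in the assertions below.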
+ self.buffer = normalizer.MeanStdBuffer(shape=(8,)) + self.num_suggestions_per_emitter = 10 + self.num_emitters = 20 + self.initial_step_size = 1.0 + self.num_evals = 100 + + self.initial_params = np.ones((13,)) + + self.test_algorithm = pyribs_algorithm.PyRibsAlgorithm( + measure_specs=[_X_SPEC, _Y_SPEC], + obs_norm_data_buffer=self.buffer, + initial_step_size=self.initial_step_size, + num_suggestions_per_emitter=self.num_suggestions_per_emitter, + num_emitters=self.num_emitters, + num_evals=self.num_evals, + ) + self.test_algorithm.initialize( + {algorithm.PARAMS_TO_EVAL: self.initial_params} + ) + + def test_initialize_with_obs_norm_state(self): + self.test_algorithm.initialize({ + algorithm.PARAMS_TO_EVAL: self.initial_params, + algorithm.OBS_NORM_BUFFER_STATE: self.buffer.state, + }) + + np.testing.assert_equal( + self.buffer.state, + self.test_algorithm.state[algorithm.OBS_NORM_BUFFER_STATE], + ) + np.testing.assert_equal( + self.initial_params, self.test_algorithm.state[algorithm.PARAMS_TO_EVAL] + ) + + def test_get_param_suggestions(self): + suggestions = self.test_algorithm.get_param_suggestions() + + self.assertLen( + suggestions, self.num_emitters * self.num_suggestions_per_emitter + ) + for suggestion in suggestions: + self.assertLen( + suggestion[algorithm.PARAMS_TO_EVAL], self.initial_params.size + ) + np.testing.assert_equal( + suggestion[algorithm.OBS_NORM_BUFFER_STATE], self.buffer.state + ) + self.assertTrue(suggestion[algorithm.UPDATE_OBS_NORM_BUFFER]) + + def test_get_param_suggestions_for_eval_is_empty_initially(self): + self.assertEmpty(self.test_algorithm.get_param_suggestions(evaluate=True)) + + def test_get_param_suggestions_for_eval(self): + suggestions = self.test_algorithm.get_param_suggestions() + evaluations = [ + worker_util.EvaluationResult( + params_evaluated=suggestion[algorithm.PARAMS_TO_EVAL], + value=1, + obs_norm_buffer_data=suggestion[ + algorithm.OBS_NORM_BUFFER_STATE + ] | {normalizer.N: 1, normalizer.UNNORM_VAR: np.ones((8,))}, + metrics={'x': 1, 'y': 10}, + ) + for suggestion in suggestions + ] + # Give the first evaluation a high score so it is the elite. + evaluations[0].value = 1000 + if evaluations[0].obs_norm_buffer_data is not None: + evaluations[0].obs_norm_buffer_data[normalizer.N] = 1000 + self.test_algorithm.process_evaluations(evaluations) + + eval_suggestions = self.test_algorithm.get_param_suggestions(evaluate=True) + + self.assertLen(eval_suggestions, self.num_evals) + for eval_suggestion in eval_suggestions: + np.testing.assert_equal( + eval_suggestion[algorithm.PARAMS_TO_EVAL], + evaluations[0].params_evaluated + ) + np.testing.assert_equal( + eval_suggestion[algorithm.OBS_NORM_BUFFER_STATE][normalizer.N], + evaluations[0].obs_norm_buffer_data[normalizer.N], + ) + self.assertFalse(eval_suggestion[algorithm.UPDATE_OBS_NORM_BUFFER]) + + def test_restore_state_from_checkpoint_without_archive(self): + checkpoint_buffer = normalizer.MeanStdBuffer(shape=(8,)) + checkpoint_buffer.push(np.ones(8,)) + checkpoint_state = { + algorithm.PARAMS_TO_EVAL: np.zeros((13,)), + algorithm.OBS_NORM_BUFFER_STATE: checkpoint_buffer.state, + } + self.test_algorithm.restore_state_from_checkpoint(checkpoint_state) + state_after_checkpoint = self.test_algorithm.state + + np.testing.assert_equal( + state_after_checkpoint[algorithm.PARAMS_TO_EVAL], + checkpoint_state[algorithm.PARAMS_TO_EVAL], + ) + np.testing.assert_equal( + state_after_checkpoint[algorithm.OBS_NORM_BUFFER_STATE], + checkpoint_buffer.state, + ) + # Archive not restored so it should be empty. 
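+    # `_ARCHIVE_DATA` maps column names to arrays; with nothing added to the
+    # archive, the solution column has length zero.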
+ self.assertEmpty( + state_after_checkpoint[pyribs_algorithm._ARCHIVE_DATA][ + pyribs_algorithm._SOLUTION + ] + ) + + def test_restore_state_from_checkpoint_with_archive(self): + checkpoint_buffer = normalizer.MeanStdBuffer(shape=(8,)) + checkpoint_buffer.push( + np.ones( + 8, + ) + ) + buffer_state = checkpoint_buffer.state + checkpoint_archive = archives.GridArchive( + solution_dim=self.initial_params.size, + dims=(_X_SPEC.num_buckets, _Y_SPEC.num_buckets), + ranges=(_X_SPEC.range, _Y_SPEC.range), + qd_score_offset=0, + extra_fields={ + pyribs_algorithm._OBS_NORM_MEAN: ( + buffer_state[normalizer.MEAN].size, + np.float32, + ), + pyribs_algorithm._OBS_NORM_STD: ( + buffer_state[normalizer.STD].size, + np.float32, + ), + pyribs_algorithm._OBS_NORM_N: ((), np.int32), + }, + ) + # Add 3 solutions to the archive to be restored. + checkpoint_archive.add( + solution=[np.ones((13,)), np.ones((13,))*2, np.ones((13,))*3], + objective=[1, 2, 3], + measures=[(1, 10), (2, 20), (3, 30)], + obs_norm_mean=[np.ones((8,)), np.ones((8,))*2, np.ones((8,))*3], + obs_norm_std=[np.ones((8,)), np.ones((8,))*2, np.ones((8,))*3], + obs_norm_n=[1, 2, 3], + ) + + checkpoint_state = { + algorithm.PARAMS_TO_EVAL: np.zeros((13,)), + algorithm.OBS_NORM_BUFFER_STATE: checkpoint_buffer.state, + pyribs_algorithm._ARCHIVE_DATA: checkpoint_archive.data(), + } + self.test_algorithm.restore_state_from_checkpoint(checkpoint_state) + state_after_checkpoint = self.test_algorithm.state + + np.testing.assert_equal( + state_after_checkpoint[algorithm.PARAMS_TO_EVAL], + checkpoint_state[algorithm.PARAMS_TO_EVAL], + ) + np.testing.assert_equal( + state_after_checkpoint[algorithm.OBS_NORM_BUFFER_STATE], + checkpoint_buffer.state, + ) + # Archive should have 3 restored elements in it. + self.assertLen( + state_after_checkpoint[pyribs_algorithm._ARCHIVE_DATA][ + pyribs_algorithm._SOLUTION + ], + 3, + ) + + def test_process_evaluations(self): + evaluations = [ + worker_util.EvaluationResult( + params_evaluated=np.ones((13,)), + value=1, + obs_norm_buffer_data={ + normalizer.N: 1, + normalizer.STD: np.ones((8,)), + normalizer.MEAN: np.ones((8,)), + normalizer.UNNORM_VAR: np.ones((8,)), + }, + metrics={'x': 1, 'y': 10}, + ), + worker_util.EvaluationResult( + params_evaluated=np.ones((13,) * 2), + value=2, + obs_norm_buffer_data={ + normalizer.N: 2, + normalizer.STD: np.ones((8,)) * 2, + normalizer.MEAN: np.ones((8,)) * 2, + normalizer.UNNORM_VAR: np.ones((8,)) * 2, + }, + metrics={'x': 2, 'y': 20}, + ), + ] + + with mock.patch.object( + self.test_algorithm._scheduler, 'tell', autospec=True + ) as mock_tell: + self.test_algorithm.process_evaluations(evaluations) + + # There are no outputs, so just check the function was called correctly. 
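+    # `tell` is invoked once per `process_evaluations` call; the obs-norm
+    # statistics ride along as the archive's extra fields.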
+    mock_tell.assert_called_once()
+    self.assertEqual(mock_tell.call_args.kwargs['objective'], [1, 2])
+    self.assertEqual(
+        mock_tell.call_args.kwargs['measures'], [[1, 10], [2, 20]]
+    )
+    np.testing.assert_equal(
+        mock_tell.call_args.kwargs[pyribs_algorithm._OBS_NORM_MEAN],
+        [
+            np.ones((8,)),
+            np.ones((8,)) * 2,
+        ],
+    )
+    np.testing.assert_equal(
+        mock_tell.call_args.kwargs[pyribs_algorithm._OBS_NORM_STD],
+        [
+            np.ones((8,)),
+            np.ones((8,)) * 2,
+        ],
+    )
+    self.assertEqual(
+        mock_tell.call_args.kwargs[pyribs_algorithm._OBS_NORM_N], [1, 2]
+    )
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/iris/algorithms/rbo_algorithm.py b/iris/algorithms/rbo_algorithm.py
new file mode 100644
index 0000000..6196b29
--- /dev/null
+++ b/iris/algorithms/rbo_algorithm.py
@@ -0,0 +1,102 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Algorithm class for the Robust Blackbox Optimization (RBO) algorithm.
+
+Algorithm class for the Robust Blackbox Optimization (RBO) algorithm from the
+paper: Provably Robust Blackbox Optimization for Reinforcement Learning
+(https://arxiv.org/abs/1903.02993) (CoRL 2019).
+"""
+
+from typing import Sequence
+
+from iris import worker_util
+from iris.algorithms import ars_algorithm
+from iris.algorithms import optimizers
+import numpy as np
+
+
+class RBO(ars_algorithm.AugmentedRandomSearch):
+  """Robust Blackbox Optimization Algorithm."""
+
+  def __init__(self,
+               regression_method: str = "ridge",
+               regularizer: float = 0.01,
+               **kwargs) -> None:
+    """Initializes the Robust Blackbox Optimization (RBO) algorithm.
+
+    Args:
+      regression_method: Type of the regression method used for gradient
+        retrieval. Currently supported methods include: LP-decoding ("lp"),
+        Lasso regression ("lasso") and ridge regression ("ridge").
+      regularizer: Regression regularizer used for gradient retrieval.
+      **kwargs: Other keyword arguments for base class.
+    """
+    super().__init__(**kwargs)
+    if regression_method == "lasso":
+      self._regression_method = optimizers.lasso_regression_jacobian_decoder
+    elif regression_method == "lp":
+      self._regression_method = optimizers.l1_jacobian_decoder
+    elif regression_method == "ridge":
+      self._regression_method = optimizers.ridge_regression_jacobian_decoder
+    else:
+      raise ValueError("Invalid regression_method")
+    self._regularizer = regularizer
+
+  def process_evaluations(
+      self, eval_results: Sequence[worker_util.EvaluationResult]) -> None:
+    """Processes the Blackbox function evaluations returned from workers.
+
+    The gradient is computed by applying a particular regression procedure.
+    The current parameter vector is then updated in the gradient direction
+    with the specified step size.
+
+    Args:
+      eval_results: List containing Blackbox function evaluations based on the
+        order in which the suggestions were sent. RBO performs a
+        gradient-based update, where the gradient is retrieved via a
+        regression procedure. The particular type of the regression procedure
+        applied is specified in the constructor.
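+
+        With antithetic sampling, the decoder approximately solves the linear
+        system P g = y for g, where row i of P holds the perturbation
+        delta_i = x_i - x and y_i = (f(x + delta_i) - f(x - delta_i)) / 2,
+        so g estimates the gradient of f at the current parameters x.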
+ """ + + # Retrieve delta direction from the param suggestion sent for evaluation. + pos_eval_results = eval_results[:self._num_suggestions] + neg_eval_results = eval_results[self._num_suggestions:] + filtered_pos_eval_results = [] + filtered_neg_eval_results = [] + for i in range(len(pos_eval_results)): + if (pos_eval_results[i].params_evaluated.size) and ( + neg_eval_results[i].params_evaluated.size): + filtered_pos_eval_results.append(pos_eval_results[i]) + filtered_neg_eval_results.append(neg_eval_results[i]) + params = np.array([r.params_evaluated for r in filtered_pos_eval_results]) + perturbations = params - self._opt_params + eval_results = filtered_pos_eval_results + filtered_neg_eval_results + pos_evals = np.array([r.value for r in filtered_pos_eval_results]) + neg_evals = np.array([r.value for r in filtered_neg_eval_results]) + evals = (pos_evals - neg_evals) / 2.0 + + # Estimate gradients via regression. + gradient = self._regression_method( + np.transpose(perturbations), np.expand_dims(evals, 1), + self._regularizer) + gradient = np.reshape(gradient, (len(gradient[0]))) + + # Apply gradients + self._opt_params += self._step_size * gradient + + # Update the observation buffer + if self._obs_norm_data_buffer is not None: + for r in eval_results: + self._obs_norm_data_buffer.merge(r.obs_norm_buffer_data) diff --git a/iris/algorithms/rbo_algorithm_test.py b/iris/algorithms/rbo_algorithm_test.py new file mode 100644 index 0000000..b07be70 --- /dev/null +++ b/iris/algorithms/rbo_algorithm_test.py @@ -0,0 +1,84 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from iris import worker_util +from iris.algorithms import rbo_algorithm +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + + +class AlgorithmTest(parameterized.TestCase): + + def test_rbo_gradient(self): + algo = rbo_algorithm.RBO( + num_suggestions=4, + step_size=0.5, + std=1., + regularizer=0.01, + regression_method='lasso', + random_seed=7) + init_state = {'init_params': np.array([10., 10.])} + algo.initialize(init_state) + eval_results = [ + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + ] + algo.process_evaluations(eval_results) + np.testing.assert_array_almost_equal( + algo._opt_params, np.array([10., 10.]), decimal=3) + + @parameterized.parameters( + ('lasso', False, False), + ('ridge', False, False), + ('lp', False, False), + ('lasso', True, False), + ('lasso', False, True), + ('lasso', True, True), + ) + def test_rbo_gradient_2(self, regression_method, orthogonal_suggestions, + quasirandom_suggestions): + algo = rbo_algorithm.RBO( + num_suggestions=4, + step_size=0.5, + std=1., + regularizer=0.01, + orthogonal_suggestions=orthogonal_suggestions, + quasirandom_suggestions=quasirandom_suggestions, + regression_method=regression_method, + random_seed=7) + init_state = {'init_params': np.array([10., 10.])} + algo.initialize(init_state) + eval_results = [ + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 11.]), 10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.array([10., 9.]), -10), # pytype: disable=wrong-arg-types # numpy-scalars + worker_util.EvaluationResult(np.empty(0), 0), # pytype: disable=wrong-arg-types # numpy-scalars + ] + algo.process_evaluations(eval_results) + np.testing.assert_equal(len(algo._opt_params), 2) + + +if __name__ == '__main__': + absltest.main() diff --git a/iris/algorithms/run_split_checkpoint.py b/iris/algorithms/run_split_checkpoint.py new file mode 100644 index 0000000..a7d0b1a --- /dev/null +++ b/iris/algorithms/run_split_checkpoint.py @@ -0,0 +1,59 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Splits a Blackbox v2 checkpoint into checkpoints per agent."""
+
+from collections.abc import Sequence
+
+from absl import app
+from absl import flags
+from iris import normalizer
+from iris.algorithms import ars_algorithm
+
+_NUM_AGENTS = flags.DEFINE_integer('num_agents', 2, 'Number of agents.')
+_HAS_OBS_NORM = flags.DEFINE_boolean(
+    'has_obs_norm', True,
+    'Whether the checkpoint has observation normalization')
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path', None, 'Path to checkpoint.', required=True)
+
+
+def split_and_save_checkpoint(checkpoint_path: str,
+                              num_agents: int = 2,
+                              has_obs_norm_data_buffer: bool = False) -> None:
+  """Splits the checkpoint at checkpoint_path into num_agents checkpoints."""
+  algo = ars_algorithm.MultiAgentAugmentedRandomSearch(
+      num_suggestions=3,
+      step_size=0.5,
+      std=1.0,
+      top_percentage=1,
+      orthogonal_suggestions=True,
+      quasirandom_suggestions=True,
+      obs_norm_data_buffer=normalizer.MeanStdBuffer()
+      if has_obs_norm_data_buffer else None,
+      agent_keys=[str(i) for i in range(num_agents)],
+      random_seed=7)
+  algo.split_and_save_checkpoint(checkpoint_path=checkpoint_path)
+
+
+def main(argv: Sequence[str]) -> None:
+  if len(argv) > 1:
+    raise app.UsageError('Too many command-line arguments.')
+  split_and_save_checkpoint(checkpoint_path=_CHECKPOINT_PATH.value,
+                            num_agents=_NUM_AGENTS.value,
+                            has_obs_norm_data_buffer=_HAS_OBS_NORM.value)
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/iris/algorithms/stateless_perturbation_generators.py b/iris/algorithms/stateless_perturbation_generators.py
new file mode 100644
index 0000000..c4118cd
--- /dev/null
+++ b/iris/algorithms/stateless_perturbation_generators.py
@@ -0,0 +1,348 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Library for generating ensembles of perturbations for Blackbox training.
+
+A library for generating ensembles of perturbations for Blackbox training.
+The generator is stateless, i.e. it does not need the current point of the
+optimization to calculate the directions. Thus the library supports in
+particular a rich class of algorithms generating structured (orthogonal or
+quasi-orthogonal, QMC) ensembles.
+"""
+
+import abc
+import math
+
+import numpy as np
+import scipy.stats as ss
+
+
+class MatrixGenerator(metaclass=abc.ABCMeta):
+  r"""Abstract class for generating matrices with rows encoding perturbations.
+
+  Class is responsible for constructing matrices with rows encoding
+  perturbations for the Blackbox training. 
The matrices are of the shape [m,d], + where m stands for number of perturbations and d for perturbations' + dimensionality. + """ + + @abc.abstractmethod + def generate_matrix(self): + r"""Returns the generated 2D matrix. + + Creates a 2D matrix. + + Args: + + Returns: + Generated 2D matrix. + """ + raise NotImplementedError('Abstract method') + + +class GaussianUnstructuredMatrixGenerator(MatrixGenerator): + r"""Derives from MatrixGenerator and creates a Gaussian matrix. + + Class responsible for constructing unstructured Gaussian matrix with entries + taken independently at random from N(0,1). + """ + + def __init__(self, num_suggestions, dim): + self.num_suggestions = num_suggestions + self.dim = dim + super().__init__() + + def generate_matrix(self): + return np.random.normal(size=(self.num_suggestions, self.dim)) + + +class GaussianOrthogonalMatrixGenerator(MatrixGenerator): + r"""Derives from MatrixGenerator and creates Gaussian orthogonal matrix. + + Class responsible for constructing block-orthogonal Gaussian matrix with: + different blocks constructed independently, orthogonal rows within a fixed + d x d block and marginal distributions of rows N(0,I_d). + """ + + def __init__(self, num_suggestions, dim, deterministic_lengths): + self.num_suggestions = num_suggestions + self.dim = dim + self.deterministic_lengths = deterministic_lengths + super().__init__() + + def generate_matrix(self): + nb_full_blocks = int(self.num_suggestions / self.dim) + block_list = [] + for _ in range(nb_full_blocks): + unstructured_block = np.random.normal(size=(self.dim, self.dim)) + q, _ = np.linalg.qr(unstructured_block) + q = np.transpose(q) + block_list.append(q) + remaining_rows = self.num_suggestions - nb_full_blocks * self.dim + if remaining_rows > 0: + unstructured_block = np.random.normal(size=(self.dim, self.dim)) + q, _ = np.linalg.qr(unstructured_block) + q = np.transpose(q) + block_list.append(q[0:remaining_rows]) + final_matrix = np.vstack(block_list) + + if not self.deterministic_lengths: + multiplier = np.linalg.norm( + np.random.normal(size=(self.num_suggestions, self.dim)), axis=1) + else: + multiplier = np.sqrt(float(self.dim)) * np.ones((self.num_suggestions)) + + return np.matmul(np.diag(multiplier), final_matrix) + + +class SignMatrixGenerator(MatrixGenerator): + r"""Derives from MatrixGenerator and creates Sign matrix. + + Class responsible for constructing a matrix with random entries from {-1,+1}. + """ + + def __init__(self, num_suggestions, dim): + self.num_suggestions = num_suggestions + self.dim = dim + super().__init__() + + def generate_matrix(self): + return np.sign(np.random.normal(size=(self.num_suggestions, self.dim))) + + +class SphereMatrixGenerator(MatrixGenerator): + r"""Derives from MatrixGenerator and creates normalized Gaussian matrix. + + Class responsible for constructing a normalized Gaussian matrix with rows of + length sqrt{d}. + """ + + def __init__(self, num_suggestions, dim): + self.num_suggestions = num_suggestions + self.dim = dim + super().__init__() + + def generate_matrix(self): + gaussian_unnormalized = np.random.normal( + size=(self.num_suggestions, self.dim)) + lengths = np.linalg.norm(gaussian_unnormalized, axis=1, keepdims=True) + return np.sqrt(self.dim) * (gaussian_unnormalized / lengths) + + +class RandomHadamardMatrixGenerator(MatrixGenerator): + r"""Derives from MatrixGenerator and creates random Hadamard matrix. 
+
+  Class responsible for constructing a random Hadamard matrix HD, where
+  H is a Kronecker-product Hadamard matrix and D is a random diagonal matrix
+  with entries on the diagonal taken independently and uniformly at random
+  from the two-element discrete set {-1,+1}.
+  Since H is a Kronecker-product Hadamard matrix, it is assumed that dim is
+  a power of two (if necessary, by padding extra zeros).
+  """
+
+  def __init__(self, num_suggestions, dim):
+    self.num_suggestions = num_suggestions
+    self.dim = dim
+    full_matrix_size = 1
+    while full_matrix_size < dim:
+      full_matrix_size = 2 * full_matrix_size
+    # Build the Sylvester (Kronecker-product) Hadamard matrix of the smallest
+    # power-of-two size covering dim.
+    nph = np.tile(1.0, (full_matrix_size, full_matrix_size))
+    i = 1
+    while i < full_matrix_size:
+      for j in range(i):
+        for k in range(i):
+          nph[j + i][k] = nph[j][k]
+          nph[j][k + i] = nph[j][k]
+          nph[j + i][k + i] = -nph[j][k]
+      i += i
+
+    self.core_hadamard = nph
+    self.extended_dim = full_matrix_size
+    super().__init__()
+
+  def generate_matrix(self):
+    ones = np.ones((self.extended_dim))
+    minus_ones = np.negative(ones)
+    diagonal = np.where(
+        np.random.uniform(size=(self.extended_dim)) < 0.5, ones, minus_ones)
+    final_list = []
+    for i in range(min(self.num_suggestions, self.extended_dim)):
+      pointwise_product = np.multiply(self.core_hadamard[i], diagonal)
+      final_list.append(pointwise_product)
+    return np.array(final_list)
+
+
+def create_rect_kac_matrix(num_rows, dim, number_of_blocks, angles,
+                           indices_pairs):
+  r"""Creates a rectangular Kac's random walk matrix for given rotations.
+
+  Outputs a submatrix of the Kac's random walk matrix truncated to its first
+  `num_rows` rows for a given list of angles and indices defining
+  low-dimensional rotations. The Kac's random walk matrix is a product of
+  Givens rotations. Each Givens rotation is characterized by its angle and a
+  pair of indices characterizing the 2-dimensional space spanned by two
+  canonical vectors, where the rotation occurs.
+
+  Args:
+    num_rows: number of first rows output
+    dim: number of rows/columns of the full Kac's random walk matrix
+    number_of_blocks: number of Givens random rotations used to create the
+      full Kac's random walk matrix
+    angles: list of angles used to construct Givens random rotations
+    indices_pairs: list of pairs of indices used to construct Givens random
+      rotations
+
+  Returns:
+    Kac's random walk matrix.
+  """
+  matrix_as_list = []
+  for index in range(min(num_rows, dim)):
+    base_vector = np.zeros(dim)
+    np.put(base_vector, index, 1.0)
+    for j in range(number_of_blocks):
+      angle = angles[j]
+      p = indices_pairs[j][0]
+      q = indices_pairs[j][1]
+      if p > q:
+        u = p
+        p = q
+        q = u
+      # Read both coordinates before writing so that the rotation uses the
+      # pre-update values of base_vector[p] and base_vector[q].
+      base_vector_p = base_vector[p]
+      base_vector_q = base_vector[q]
+      base_vector[p] = math.cos(angle) * base_vector_p - math.sin(
+          angle) * base_vector_q
+      base_vector[q] = math.sin(angle) * base_vector_p + math.cos(
+          angle) * base_vector_q
+    matrix_as_list.append(base_vector)
+  return math.sqrt(float(dim)) * np.array(matrix_as_list)
+
+
+class KacMatrixGenerator(MatrixGenerator):
+  r"""Derives from MatrixGenerator and creates Kac's random walk matrix.
+
+  Class responsible for constructing a submatrix of the Kac's random walk
+  matrix obtained from the full Kac's random walk matrix by truncating it to
+  its first `num_suggestions` rows.
+  """
+
+  def __init__(self, num_suggestions, dim, number_of_blocks):
+    r"""Constructor of the Kac's random walk matrix generator.
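+
+    The generated matrix consists of the first `num_suggestions` rows of a
+    product of `number_of_blocks` random Givens rotations, scaled by
+    sqrt(dim); larger `number_of_blocks` brings the rows closer to those of a
+    uniformly random rotation.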
+
+    Args:
+      num_suggestions: number of the rows of the constructed matrix
+      dim: the number of the columns of the constructed matrix
+      number_of_blocks: number of blocks of the applied Kac's random walk
+        matrix
+    """
+    self.num_suggestions = num_suggestions
+    self.dim = dim
+    self.number_of_blocks = number_of_blocks
+    super().__init__()
+
+  def generate_matrix(self):
+    angles = np.random.uniform(
+        low=0.0, high=2.0 * np.pi, size=(self.number_of_blocks))
+    indices_pairs = np.random.choice(
+        np.arange(self.dim), size=(self.number_of_blocks, 2))
+    return create_rect_kac_matrix(self.num_suggestions, self.dim,
+                                  self.number_of_blocks, angles, indices_pairs)
+
+
+def phi_reflection(b, x):
+  r"""Outputs \phi function used in the computation of Halton sequences.
+
+  Outputs a function \phi_{b} defined as follows:
+
+  \phi_{b}(y_{k}y_{k-1}...y_{0}_{b}) = y_{0}/b + y_{1}/b^{2} + ...,
+
+  where y_{k}y_{k-1}...y_{0}_{b} stands for the representation of the number
+  using base b.
+
+  Args:
+    b: base used by the \phi function
+    x: input to the \phi function
+
+  Returns:
+    \phi_{b}(x)
+  """
+  b_f = float(b)
+  x_f = float(x)
+  coefficients = []
+  while x_f >= b_f:
+    w = math.floor(x_f / b_f) * b_f
+    r = x_f - w
+    coefficients.append(r)
+    x_f = math.floor(x_f / b_f)
+  coefficients.append(x_f)
+  x_f_reflection = 0.0
+  power_of_b = b_f
+  for i in range(len(coefficients)):
+    x_f_reflection += coefficients[i] / power_of_b
+    power_of_b *= b_f
+  return x_f_reflection
+
+
+def create_rect_hal_matrix(b_set, r_array):
+  r"""Creates a Halton matrix.
+
+  Creates a Halton matrix using inputs from the list `r_array` and base values
+  for \phi function parametrization from the list `b_set`.
+
+  Args:
+    b_set: set of base values
+    r_array: inputs to the \phi function
+
+  Returns:
+    corresponding Halton matrix
+  """
+  rows = []
+  for i in range(len(b_set)):
+    next_row = []
+    for j in range(len(r_array)):
+      next_row.append(ss.norm.ppf(phi_reflection(b_set[i], r_array[j])))
+    rows.append(next_row)
+  return np.array(rows)
+
+
+class HaltonMatrixGenerator(MatrixGenerator):
+  r"""Derives from MatrixGenerator and creates Halton matrix.
+
+  Class responsible for constructing a subsampled version of the Halton matrix
+  H with rows defined as follows:
+
+  h_{j} = (\phi_{b_{1}}(j),..., \phi_{b_{dim}}(j))^{T}, where
+
+  b_{1},...,b_{dim} is a set of numbers such that gcd(b_{i},b_{k}) = 1 for
+  i != k and \phi_{y} is defined as follows:
+
+  \phi_{y}(y_{0} + y_{1}*y + y_{2}*y^{2} + ...) = y_{0}/y + y_{1}/y^{2} + ...
+
+  for y_{0},y_{1},... < y.
+  """
+
+  def __init__(self, num_suggestions, dim, b_set):
+    r"""Constructor of the Halton matrix generator.
+
+    Args:
+      num_suggestions: number of the rows of the constructed matrix
+      dim: the number of the columns of the constructed matrix
+      b_set: the list of bases: [b_{1},...,b_{dim}] (see: explanation above)
+    """
+    self.num_suggestions = num_suggestions
+    self.dim = dim
+    self.b_set = b_set
+    super().__init__()
+
+  def generate_matrix(self):
+    return np.transpose(
+        create_rect_hal_matrix(
+            np.array(self.b_set),
+            np.arange(min(self.num_suggestions + 1, self.dim))))[1:]
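+
+
+# Illustrative usage (editorial sketch, not part of the original library): a
+# minimal sanity check of the generators above. The class names and
+# constructor arguments come from this module, while the sizes, the Halton
+# bases and the main guard are assumptions made for demonstration only.
+if __name__ == "__main__":
+  m, d = 8, 16
+  # Pairwise-coprime bases for the Halton construction (first 16 primes).
+  halton_bases = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+  generators = [
+      GaussianUnstructuredMatrixGenerator(m, d),
+      GaussianOrthogonalMatrixGenerator(m, d, deterministic_lengths=True),
+      SignMatrixGenerator(m, d),
+      SphereMatrixGenerator(m, d),
+      RandomHadamardMatrixGenerator(m, d),
+      KacMatrixGenerator(m, d, number_of_blocks=64),
+      HaltonMatrixGenerator(m, d, halton_bases),
+  ]
+  for generator in generators:
+    # Each row of the returned matrix encodes one perturbation of
+    # dimensionality d.
+    matrix = generator.generate_matrix()
+    print(type(generator).__name__, matrix.shape)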