diff --git a/iris/algorithms/ars_algorithm_test.py b/iris/algorithms/ars_algorithm_test.py index 349551d..85ae15d 100644 --- a/iris/algorithms/ars_algorithm_test.py +++ b/iris/algorithms/ars_algorithm_test.py @@ -90,53 +90,68 @@ def test_restore_state_from_checkpoint(self, expected_obs_norm_state): else None, random_seed=7, ) - init_state = {'init_params': np.array([10., 10.])} + init_state = {'init_params': np.array([10.0, 10.0])} if expected_obs_norm_state: - init_state['obs_norm_buffer_data'] = {'mean': np.asarray([0., 0.]), - 'std': np.asarray([1., 1.]), - 'n': 0} + init_state['obs_norm_buffer_data'] = { + 'mean': np.asarray([0.0, 0.0]), + 'std': np.asarray([1.0, 1.0]), + 'n': 0, + } algo.initialize(init_state) - + # self.assertIsNotNone(algo._obs_norm_data_buffer) with self.subTest('init-mean'): - self.assertAllClose(np.array(algo._opt_params), - init_state['init_params']) - if expected_obs_norm_state is not None: + self.assertAllClose(np.array(algo._opt_params), init_state['init_params']) + if ( + expected_obs_norm_state is not None + and algo._obs_norm_data_buffer is not None + ): with self.subTest('init-obs-mean'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['mean']), - np.asarray(init_state['obs_norm_buffer_data']['mean'])) + np.asarray(init_state['obs_norm_buffer_data']['mean']), + ) with self.subTest('init-obs-n'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['n']), - np.asarray(init_state['obs_norm_buffer_data']['n'])) + np.asarray(init_state['obs_norm_buffer_data']['n']), + ) with self.subTest('init-obs-std'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['std']), - init_state['obs_norm_buffer_data']['std']) + init_state['obs_norm_buffer_data']['std'], + ) - expected_restore_state = {'params_to_eval': np.array([5., 6.])} + expected_restore_state = {'params_to_eval': np.array([5.0, 6.0])} if expected_obs_norm_state is not None: expected_restore_state['obs_norm_state'] = expected_obs_norm_state algo.restore_state_from_checkpoint(expected_restore_state) - self.assertAllClose(algo._opt_params, - expected_restore_state['params_to_eval']) - if expected_obs_norm_state is not None: + self.assertAllClose( + algo._opt_params, expected_restore_state['params_to_eval'] + ) + if ( + expected_obs_norm_state is not None + and algo._obs_norm_data_buffer is not None + ): std = expected_restore_state['obs_norm_state']['std'] var = np.square(std) expected_unnorm_var = var * 4 with self.subTest('restore-obs-mean'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['mean']), - np.asarray(expected_restore_state['obs_norm_state']['mean'])) + np.asarray(expected_restore_state['obs_norm_state']['mean']), + ) with self.subTest('restore-obs-n'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['n']), - np.asarray(expected_restore_state['obs_norm_state']['n'])) + np.asarray(expected_restore_state['obs_norm_state']['n']), + ) with self.subTest('restore-obs-std'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['unnorm_var']), - expected_unnorm_var) + expected_unnorm_var, + ) + if __name__ == '__main__': absltest.main() diff --git a/iris/algorithms/multi_agent_ars_algorithm_test.py b/iris/algorithms/multi_agent_ars_algorithm_test.py index dfc78f4..67af314 100644 --- a/iris/algorithms/multi_agent_ars_algorithm_test.py +++ b/iris/algorithms/multi_agent_ars_algorithm_test.py @@ -233,7 +233,10 @@ def test_restore_state_from_checkpoint( with self.subTest('init-mean'): self.assertAllClose(np.array(algo._opt_params), init_state['init_params']) - if state['obs_norm_state'] is not None: + if ( + state['obs_norm_state'] is not None + and algo._obs_norm_data_buffer is not None + ): with self.subTest('init-obs-mean'): self.assertAllClose( np.asarray(algo._obs_norm_data_buffer.data['mean']), @@ -253,7 +256,10 @@ def test_restore_state_from_checkpoint( algo.restore_state_from_checkpoint(state) self.assertAllClose(algo._opt_params, expected_state['params_to_eval']) - if expected_state['obs_norm_state'] is not None: + if ( + expected_state['obs_norm_state'] is not None + and algo._obs_norm_data_buffer is not None + ): std = expected_state['obs_norm_state']['std'] var = np.square(std) expected_unnorm_var = var * 4 diff --git a/iris/normalizer.py b/iris/normalizer.py index 8172f96..16ecd1d 100644 --- a/iris/normalizer.py +++ b/iris/normalizer.py @@ -18,7 +18,6 @@ import copy from typing import Any, Dict, Optional, Sequence, Union from absl import logging -import gin import gym from gym import spaces from gym.spaces import utils @@ -106,13 +105,16 @@ def state(self, new_state: Dict[str, Any]) -> None: pass -@gin.configurable class MeanStdBuffer(Buffer): """Collect stats for calculating mean and std online.""" def __init__(self, shape: Sequence[int] = (0,)) -> None: self._shape = shape - self._data = {} + self._data = { + N: 0, + MEAN: np.zeros(self._shape, dtype=np.float64), + UNNORM_VAR: np.zeros(self._shape, dtype=np.float64), + } self.reset() def reset(self) -> None: @@ -283,7 +285,6 @@ def state(self, state: Dict[str, np.ndarray]) -> None: self._state = state.copy() -@gin.configurable class NoNormalizer(Normalizer): """No Normalization applied to input.""" @@ -300,7 +301,6 @@ def __call__( return value -@gin.configurable class ActionRangeDenormalizer(Normalizer): """Actions mapped to given range from [-1, 1].""" @@ -341,7 +341,6 @@ def __call__( return action -@gin.configurable class ObservationRangeNormalizer(Normalizer): """Observations mapped from given range to [-1, 1].""" @@ -383,7 +382,6 @@ def __call__( return observation -@gin.configurable class RunningMeanStdNormalizer(Normalizer): """Standardize observations with mean and std calculated online.""" diff --git a/iris/normalizer_test.py b/iris/normalizer_test.py index d32f47f..4f7a946 100644 --- a/iris/normalizer_test.py +++ b/iris/normalizer_test.py @@ -21,11 +21,11 @@ class BufferTest(absltest.TestCase): def test_meanstdbuffer(self): - buffer = normalizer.MeanStdBuffer((1)) + buffer = normalizer.MeanStdBuffer((1,)) buffer.push(np.asarray(10.0)) buffer.push(np.asarray(11.0)) - new_buffer = normalizer.MeanStdBuffer((1)) + new_buffer = normalizer.MeanStdBuffer((1,)) new_buffer.data = buffer.data self.assertEqual(new_buffer._std, buffer._std) @@ -145,7 +145,7 @@ def test_mean_std_buffer_empty_merge(self): self.assertEqual(mean_std_buffer._data['n'], 0) def test_mean_std_buffer_scalar(self): - mean_std_buffer = normalizer.MeanStdBuffer((1)) + mean_std_buffer = normalizer.MeanStdBuffer((1,)) mean_std_buffer.push(np.asarray(10.0)) self.assertEqual(mean_std_buffer._std, 1.0) # First value is always 1.0. @@ -154,10 +154,10 @@ def test_mean_std_buffer_scalar(self): np.testing.assert_almost_equal(mean_std_buffer._std, np.sqrt(0.5)) def test_mean_std_buffer_reject_infinity_on_merge(self): - mean_std_buffer = normalizer.MeanStdBuffer((1)) + mean_std_buffer = normalizer.MeanStdBuffer((1,)) mean_std_buffer.push(np.asarray(10.0)) - infinty_buffer = normalizer.MeanStdBuffer((1)) + infinty_buffer = normalizer.MeanStdBuffer((1,)) infinty_buffer.push(np.asarray(np.inf)) mean_std_buffer.merge(infinty_buffer.data) diff --git a/iris/policies/base_policy.py b/iris/policies/base_policy.py index 69ffd7c..b3a3889 100644 --- a/iris/policies/base_policy.py +++ b/iris/policies/base_policy.py @@ -14,6 +14,7 @@ """Policy class for computing action from weights and observation vector.""" +import abc from typing import Dict, Union import gym @@ -21,7 +22,7 @@ import numpy as np -class BasePolicy(object): +class BasePolicy(abc.ABC): """Base policy class for reinforcement learning.""" def __init__(self, ob_space: gym.Space, ac_space: gym.Space) -> None: @@ -55,23 +56,27 @@ def set_iteration(self, value: int | None): self._iteration = value def update_weights(self, new_weights: np.ndarray) -> None: + """Updates the flat weights vector.""" self._weights[:] = new_weights[:] def get_weights(self) -> np.ndarray: + """Returns the flat weights vector.""" return self._weights def get_representation_weights(self): + """Returns the flat representation weights vector.""" return self._representation_weights def update_representation_weights( self, new_representation_weights: np.ndarray) -> None: + """Updates the flat representation weights vector.""" self._representation_weights[:] = new_representation_weights[:] def reset(self): + """Resets the internal policy state.""" pass + @abc.abstractmethod def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]] ) -> Union[np.ndarray, Dict[str, np.ndarray]]: """Maps the observation to action.""" - raise NotImplementedError( - "Should be implemented in derived classes for specific policies.") diff --git a/iris/policies/nas_policy.py b/iris/policies/nas_policy.py index 46c5d47..0307947 100644 --- a/iris/policies/nas_policy.py +++ b/iris/policies/nas_policy.py @@ -25,7 +25,7 @@ import pyglove as pg -class PyGlovePolicy(abc.ABC, base_policy.BasePolicy): +class PyGlovePolicy(base_policy.BasePolicy): """Base class for all policies involving NAS search.""" @abc.abstractmethod @@ -42,40 +42,49 @@ def dna_spec(self) -> pg.DNASpec: class NumpyTopologyPolicy(PyGlovePolicy): """Parent class for numpy-based policies.""" - def __init__(self, - ob_space: gym.Space, - ac_space: gym.Space, - hidden_layer_sizes: Sequence[int], - seed: int = 0, - **kwargs): - base_policy.BasePolicy.__init__(self, ob_space, ac_space) + def __init__( + self, + ob_space: gym.Space, + ac_space: gym.Space, + hidden_layer_sizes: Sequence[int], + seed: int = 0, + **kwargs + ): + super().__init__(ob_space, ac_space) self._hidden_layer_sizes = hidden_layer_sizes - self._total_nb_nodes = sum( - self._hidden_layer_sizes) + self._ob_dim + self._ac_dim - self._all_layer_sizes = [self._ob_dim] + list( - self._hidden_layer_sizes) + [self._ac_dim] + self._total_nb_nodes = ( + sum(self._hidden_layer_sizes) + self._ob_dim + self._ac_dim + ) + self._all_layer_sizes = ( + [self._ob_dim] + list(self._hidden_layer_sizes) + [self._ac_dim] + ) self._total_weight_parameters = self._total_nb_nodes**2 self._total_bias_parameters = self._total_nb_nodes - self._total_nb_parameters = self._total_weight_parameters + self._total_bias_parameters + self._total_nb_parameters = ( + self._total_weight_parameters + self._total_bias_parameters + ) np.random.seed(seed) self._weights = np.random.uniform( - low=-1.0, high=1.0, size=(self._total_nb_nodes, self._total_nb_nodes)) + low=-1.0, high=1.0, size=(self._total_nb_nodes, self._total_nb_nodes) + ) self._biases = np.random.uniform( - low=-1.0, high=1.0, size=self._total_nb_nodes) + low=-1.0, high=1.0, size=self._total_nb_nodes + ) self._edge_dict = {} - def act(self, ob: Union[np.ndarray, Dict[str, np.ndarray]] - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def act( + self, ob: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> Union[np.ndarray, Dict[str, np.ndarray]]: ob = utils.flatten(self._ob_space, ob) values = [0.0] * self._total_nb_nodes for i in range(self._ob_dim): values[i] = ob[i] for i in range(self._total_nb_nodes): - if ((i > self._ob_dim) and (i < self._total_nb_nodes - self._ac_dim)): + if (i > self._ob_dim) and (i < self._total_nb_nodes - self._ac_dim): values[i] = np.tanh(values[i] + self._biases[i]) if i in self._edge_dict: j_list = self._edge_dict[i] diff --git a/requirements-rl.txt b/requirements-rl.txt index 3033d86..2878443 100644 --- a/requirements-rl.txt +++ b/requirements-rl.txt @@ -9,6 +9,3 @@ jax # Use latest version. jaxlib # Use latest version. flax # Use latest version. tensorflow # TODO(team): Resolve version conflicts. - -# Configuration + Experimentation -gin-config>=0.5.0 \ No newline at end of file