diff --git a/dopamine/jax/networks.py b/dopamine/jax/networks.py index 85e6ca60..d6fb30b0 100644 --- a/dopamine/jax/networks.py +++ b/dopamine/jax/networks.py @@ -50,6 +50,85 @@ def preprocess_atari_inputs(x): identity_preprocess_fn = lambda x: x +@gin.configurable +class Stack(nn.Module): + """Stack of pooling and convolutional blocks with residual connections.""" + num_ch: int + num_blocks: int + use_max_pooling: bool = True + + @nn.compact + def __call__(self, x): + initializer = nn.initializers.xavier_uniform() + conv_out = nn.Conv( + features=self.num_ch, + kernel_size=(3, 3), + strides=1, + kernel_init=initializer, + padding='SAME')( + x) + if self.use_max_pooling: + conv_out = nn.max_pool( + conv_out, window_shape=(3, 3), padding='SAME', strides=(2, 2)) + + for _ in range(self.num_blocks): + block_input = conv_out + conv_out = nn.relu(conv_out) + conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3), + strides=1, padding='SAME')(conv_out) + conv_out = nn.relu(conv_out) + conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3), + strides=1, padding='SAME')(conv_out) + conv_out += block_input + + return conv_out + + +@gin.configurable +class ImpalaEncoder(nn.Module): + """Impala Network which also outputs penultimate representation layers.""" + nn_scale: int = 1 + stack_sizes: Tuple[int, ...] = (16, 32, 32) + num_blocks: int = 2 + + def setup(self): + self._stacks = [ + Stack(num_ch=stack_size * self.nn_scale, + num_blocks=self.num_blocks) for stack_size in self.stack_sizes + ] + + @nn.compact + def __call__(self, x): + for stack in self._stacks: + x = stack(x) + return nn.relu(x) + + +### DQN Network with ImpalaEncoder ### +@gin.configurable +class ImpalaDQNNetwork(nn.Module): + """The convolutional network used to compute the agent's Q-values.""" + num_actions: int + inputs_preprocessed: bool = False + nn_scale: int = 1 + + def setup(self): + self.encoder = ImpalaEncoder(nn_scale=self.nn_scale) + + @nn.compact + def __call__(self, x): + initializer = nn.initializers.xavier_uniform() + if not self.inputs_preprocessed: + x = preprocess_atari_inputs(x) + x = self.encoder(x) + x = x.reshape((-1)) # flatten + x = nn.Dense(features=512, kernel_init=initializer)(x) + x = nn.relu(x) + q_values = nn.Dense(features=self.num_actions, + kernel_init=initializer)(x) + return atari_lib.DQNNetworkType(q_values) + + ### DQN Networks ### @gin.configurable class NatureDQNNetwork(nn.Module): diff --git a/dopamine/labs/atari_100k/atari_100k_rainbow_agent.py b/dopamine/labs/atari_100k/atari_100k_rainbow_agent.py index 0f52f4be..caf06206 100644 --- a/dopamine/labs/atari_100k/atari_100k_rainbow_agent.py +++ b/dopamine/labs/atari_100k/atari_100k_rainbow_agent.py @@ -14,6 +14,7 @@ # limitations under the License. 
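The hunk above adds ImpalaDQNNetwork alongside the existing DQN networks. Below is a minimal usage sketch (not part of the diff) following the same Flax init/apply pattern Dopamine uses for NatureDQNNetwork; the (84, 84, 4) Atari frame stack and the action count are illustrative assumptions.

import jax
import jax.numpy as jnp
from dopamine.jax import networks

network_def = networks.ImpalaDQNNetwork(num_actions=6, nn_scale=1)
rng = jax.random.PRNGKey(0)
dummy_state = jnp.zeros((84, 84, 4))           # stacked Atari frames
params = network_def.init(rng, x=dummy_state)  # builds encoder and heads
q_values = network_def.apply(params, dummy_state).q_values  # shape (6,)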
"""Atari 100k rainbow agent with support for data augmentation.""" +import copy import functools from absl import logging @@ -22,9 +23,55 @@ import gin import jax import jax.numpy as jnp +import numpy as onp import tensorflow as tf +@functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11)) +def select_action( + network_def, + params, + state, + rng, + num_actions, + eval_mode, + epsilon_eval, + epsilon_train, + epsilon_decay_period, + training_steps, + min_replay_history, + epsilon_fn, + support, +): + """Select an action from the set of available actions.""" + epsilon = jnp.where( + eval_mode, + epsilon_eval, + epsilon_fn( + epsilon_decay_period, + training_steps, + min_replay_history, + epsilon_train, + ), + ) + + rng, rng1, rng2, rng3 = jax.random.split(rng, num=4) + + @functools.partial(jax.vmap, in_axes=(0, 0), axis_name='batch') + def q_function(state, key): + q_values = network_def.apply( + params, state, key=key, eval_mode=eval_mode, support=support + ).q_values + return q_values + + q_values = q_function(state, jax.random.split(rng2, state.shape[0])) + + best_actions = jnp.argmax(q_values, axis=-1) + random_actions = jax.random.randint(rng3, (state.shape[0],), 0, num_actions) + p = jax.random.uniform(rng1, shape=(state.shape[0],)) + return rng, jnp.where(p <= epsilon, random_actions, best_actions) + + ############################ Data Augmentation ############################ @@ -118,6 +165,7 @@ def __init__(self, self.train_preprocess_fn = functools.partial( preprocess_inputs_with_augmentation, data_augmentation=data_augmentation) + self.state_shape = self.state.shape def _training_step_update(self): """Gradient update during every training step.""" @@ -156,3 +204,144 @@ def _training_step_update(self): step=self.training_steps) self.summary_writer.flush() + def step(self, reward=None, observation=None): + """Selects an action, and optionally records a transition and trains. + + If `reward` or `observation` is None, the agent's state will _not_ be + updated and nothing will be written to the buffer. The user must call + `log_transition` themselves in this case. + + Args: + reward: Optional reward to log. + observation: Optional observation to log. Must call `log_transition` later + if not passed here. + + Returns: + Selected action. + """ + if reward is not None and observation is not None: + self._last_observation = self._observation + self._record_observation(observation) + if not self.eval_mode: + self._store_transition( + self._last_observation, self.action, reward, False + ) + + if not self.eval_mode: + self._train_step() + + state = self.preprocess_fn(self.state) + self._rng, action = select_action( + self.network_def, + self.online_params, + state, + self._rng, + self.num_actions, + self.eval_mode, + self.epsilon_eval, + self.epsilon_train, + self.epsilon_decay_period, + self.training_steps, + self.min_replay_history, + self.epsilon_fn, + self._support, + ) + self.action = onp.asarray(action) + return self.action + + def _reset_state(self, n_envs=None): + """Resets the agent state by filling it with zeros.""" + if n_envs is None: + self.state = onp.zeros((1, *self.state_shape)) + else: + self.state = onp.zeros((n_envs, *self.state_shape)) + + def _record_observation(self, observation): + """Records an observation and update state. + + Extracts a frame from the observation vector and overwrites the oldest + frame in the state buffer. + + Args: + observation: numpy array, an observation from the environment. + """ + # Set current observation. 
We do the reshaping to handle environments + # without frame stacking. + observation = observation.squeeze(-1) + if len(observation.shape) == len(self.observation_shape): + self._observation = onp.reshape(observation, self.observation_shape) + else: + self._observation = onp.reshape( + observation, (observation.shape[0], *self.observation_shape) + ) + # Swap out the oldest frame with the current frame. + self.state = onp.roll(self.state, -1, axis=-1) + self.state[..., -1] = self._observation + + def reset_all(self, new_obs): + """Resets the agent state by filling it with zeros.""" + n_envs = new_obs.shape[0] + self.state = onp.zeros((n_envs, *self.state_shape)) + self._record_observation(new_obs) + + def reset_one(self, env_id): + self.state[env_id].fill(0) + + def delete_one(self, env_id): + self.state = onp.concatenate( + [self.state[:env_id], self.state[env_id + 1 :]], 0 + ) + + def cache_train_state(self): + self.training_state = ( + copy.deepcopy(self.state), + copy.deepcopy(self._last_observation), + copy.deepcopy(self._observation), + ) + + def restore_train_state(self): + (self.state, self._last_observation, self._observation) = ( + self.training_state + ) + + def log_transition(self, observation, action, reward, terminal, episode_end): + self._last_observation = self._observation + self._record_observation(observation) + + if not self.eval_mode: + self._store_transition( + self._last_observation, + action, + reward, + terminal, + episode_end=episode_end, + ) + + def _store_transition( + self, + last_observation, + action, + reward, + is_terminal, + *args, + priority=None, + episode_end=False + ): + """Stores a transition when in training mode.""" + is_prioritized = hasattr(self._replay, 'sum_tree') + if is_prioritized and priority is None: + priority = onp.ones_like(reward) + if self._replay_scheme == 'prioritized': + priority *= self._replay.sum_tree.max_recorded_priority + + to_store = (last_observation, action, reward, is_terminal, *args) + to_store = (onp.asarray(x) for x in to_store) + if not hasattr(self._replay, '_n_envs'): + to_store = (onp.squeeze(x) for x in to_store) + priority = onp.squeeze(priority) + elif hasattr(self._replay, '_n_envs') and not reward.shape: + to_store = (onp.expand_dims(x, 0) if not (x.shape and x.shape[0] == 1) + else x for x in to_store) + priority = onp.expand_dims(priority, 0) + if not self.eval_mode: + self._replay.add(*to_store, priority=priority, episode_end=episode_end) diff --git a/dopamine/labs/atari_100k/atari_100k_runner.py b/dopamine/labs/atari_100k/atari_100k_runner.py new file mode 100755 index 00000000..25adb338 --- /dev/null +++ b/dopamine/labs/atari_100k/atari_100k_runner.py @@ -0,0 +1,576 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
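The vectorized _record_observation above keeps one frame stack per environment and overwrites the oldest frame in place. A minimal numpy sketch of that update (shapes are illustrative, not taken from the diff):

import numpy as onp

n_envs, stack_size = 2, 4
state = onp.zeros((n_envs, 84, 84, stack_size))
new_frame = onp.ones((n_envs, 84, 84))   # one new observation per environment

state = onp.roll(state, -1, axis=-1)     # shift out the oldest frame
state[..., -1] = new_frame               # write the newest frame into the last slot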
+"""Runner for evaluating using a fixed number of episodes.""" + +import functools +import os +import sys +import time + +from absl import logging +from dopamine.discrete_domains import atari_lib +from dopamine.discrete_domains import iteration_statistics +from dopamine.discrete_domains import run_experiment +from dopamine.labs.atari_100k import normalization_utils +import gin +import jax +import numpy as np +import tensorflow as tf + + +def create_env_wrapper(create_env_fn): + + def inner_create(*args, **kwargs): + env = create_env_fn(*args, **kwargs) + env.cum_length = 0 + env.cum_reward = 0 + return env + + return inner_create + + +@gin.configurable +class DataEfficientAtariRunner(run_experiment.Runner): + """Runner for evaluating using a fixed number of episodes rather than steps. + + Also restricts data collection to a strict cap, + following conventions in data-efficient RL research. + """ + + def __init__( + self, + base_dir, + create_agent_fn, + game_name=None, + create_environment_fn=atari_lib.create_atari_environment, + num_eval_episodes=100, + max_noops=30, + parallel_eval=True, + num_eval_envs=100, + num_train_envs=1, + eval_one_to_one=True, + log_normalized_scores=False, + ): + """Specify the number of evaluation episodes.""" + if game_name is not None: + create_environment_fn = functools.partial( + create_environment_fn, game_name=game_name + ) + self.game_name = game_name.lower().replace('_', '').replace(' ', '') + + if log_normalized_scores: + raise ValueError( + 'Game name must not be None if logging normalized scores.' + ) + super().__init__( + base_dir, create_agent_fn, create_environment_fn=create_environment_fn) + + self._num_iterations = int(self._num_iterations) + self._start_iteration = int(self._start_iteration) + + self._num_eval_episodes = num_eval_episodes + logging.info('Num evaluation episodes: %d', num_eval_episodes) + self._evaluation_steps = None + self.num_steps = 0 + self.total_steps = self._training_steps * self._num_iterations + self.create_environment_fn = create_env_wrapper(create_environment_fn) + + self.max_noops = max_noops + self.parallel_eval = parallel_eval + self.num_eval_envs = num_eval_envs + self.num_train_envs = num_train_envs + self.eval_one_to_one = eval_one_to_one + + self.train_envs = [ + self.create_environment_fn() for i in range(num_train_envs) + ] + self.train_state = None + self._agent.reset_all(self._initialize_episode(self.train_envs)) + self._agent.cache_train_state() + + self.log_normalized_scores = log_normalized_scores + + def _run_one_phase(self, + envs, + steps, + max_episodes, + statistics, + run_mode_str, + needs_reset=False, + one_to_one=False, + resume_state=None): + """Runs the agent/environment loop until a desired number of steps. + + We terminate precisely when the desired number of steps has been reached, + unlike some other implementations. + + Args: + envs: environments to use in this phase. + steps: int, how many steps to run in this phase (or None). + max_episodes: int, maximum number of episodes to generate in this phase. + statistics: `IterationStatistics` object which records the experimental + results. + run_mode_str: str, describes the run mode for this agent. + needs_reset: bool, whether to reset all environments before starting. + one_to_one: bool, whether to precisely match each episode in + `max_episodes` to an environment in `envs`. True is faster but only + works in some situations (e.g., evaluation). + resume_state: bool, whether to have the agent resume its prior state for + the current mode. 
+ + Returns: + Tuple containing the number of steps taken in this phase (int), the + sum of + returns (float), and the number of episodes performed (int). + """ + step_count = 0 + num_episodes = 0 + sum_returns = 0. + + (episode_lengths, episode_returns, state, envs) = self._run_parallel( + episodes=max_episodes, + envs=envs, + one_to_one=one_to_one, + needs_reset=needs_reset, + resume_state=resume_state, + max_steps=steps, + ) + + for episode_length, episode_return in zip(episode_lengths, episode_returns): + statistics.append({ + '{}_episode_lengths'.format(run_mode_str): episode_length, + '{}_episode_returns'.format(run_mode_str): episode_return + }) + if run_mode_str == 'train': + # we use one extra frame at the starting + self.num_steps += episode_length + step_count += episode_length + sum_returns += episode_return + num_episodes += 1 + sys.stdout.flush() + if self._summary_writer is not None: + with self._summary_writer.as_default(): + _ = ( + tf.summary.scalar( + 'train_episode_returns', + float(episode_return), + step=self.num_steps, + ), + ) + _ = tf.summary.scalar( + 'train_episode_lengths', + float(episode_length), + step=self.num_steps, + ) + return step_count, sum_returns, num_episodes, state, envs + + def _initialize_episode(self, envs): + """Initialization for a new episode. + + Args: + envs: Environments to initialize episodes for. + + Returns: + action: int, the initial action chosen by the agent. + """ + observations = [] + for env in envs: + initial_observation = env.reset() + if self.max_noops > 0: + self._agent._rng, rng = jax.random.split( + self._agent._rng # pylint: disable=protected-access + ) + num_noops = jax.random.randint(rng, (), 0, self.max_noops) + for _ in range(num_noops): + initial_observation, _, terminal, _ = env.step(0) + if terminal: + initial_observation = env.reset() + observations.append(initial_observation) + initial_observation = np.stack(observations, 0) + + return initial_observation + + def _run_parallel(self, + envs, + episodes=None, + max_steps=None, + one_to_one=False, + needs_reset=True, + resume_state=None): + """Executes a full trajectory of the agent interacting with the environment. + + Args: + envs: Environments to step in. + episodes: Optional int, how many episodes to run. Unbounded if None. + max_steps: Optional int, how many steps to run. Unbounded if None. + one_to_one: Bool, whether to couple each episode to an environment. + needs_reset: Bool, whether to reset environments before beginning. + resume_state: State tuple to resume. + + Returns: + The number of steps taken and the total reward. + """ + # You can't ask for 200 episodes run one-to-one on 100 envs + if one_to_one: + assert episodes is None or episodes == len(envs) + + # Create envs + live_envs = list(range(len(envs))) + + if needs_reset: + new_obs = self._initialize_episode(envs) + new_obses = np.zeros((2, len(envs), *self._agent.observation_shape, 1)) + self._agent.reset_all(new_obs) + + rewards = np.zeros((len(envs),)) + terminals = np.zeros((len(envs),)) + episode_end = np.zeros((len(envs),)) + + cum_rewards = [] + cum_lengths = [] + else: + assert resume_state is not None + (new_obses, rewards, terminals, episode_end, cum_rewards, cum_lengths) = ( + resume_state + ) + + total_steps = 0 + total_episodes = 0 + max_steps = np.inf if max_steps is None else max_steps + step = 0 + + # Keep interacting until we reach a terminal state. 
+ while True: + live_env_index = 0 + step += 1 + episode_end.fill(0) + total_steps += len(live_envs) + actions = self._agent.step() + + # The agent may be hanging on to the previous new_obs, so we shouldn't + # change it. By alternating, we can make sure we don't end up logging + # with an offset while limiting memory usage. + new_obs = new_obses[step % 2] + + # don't want to do a for-loop since live envs may change + while live_env_index < len(live_envs): + env_id = live_envs[live_env_index] + obs, reward, done, _ = envs[env_id].step(actions[live_env_index]) + envs[env_id].cum_length += 1 + envs[env_id].cum_reward += reward + new_obs[live_env_index] = obs + rewards[live_env_index] = reward + terminals[live_env_index] = done + + if (envs[env_id].game_over or + envs[env_id].cum_length == self._max_steps_per_episode): + total_episodes += 1 + cum_rewards.append(envs[env_id].cum_reward) + cum_lengths.append(envs[env_id].cum_length) + envs[env_id].cum_length = 0 + envs[env_id].cum_reward = 0 + + log_str = ( + 'Steps executed: {} Num episodes: {} Episode length: {}' + ' Return: {} '.format( + total_steps, + len(cum_rewards), + cum_lengths[-1], + cum_rewards[-1], + ) + ) + logging.info(log_str) + self._maybe_save_single_summary(self.num_steps + total_steps, + cum_rewards[-1], cum_lengths[-1]) + + if one_to_one: + new_obses = delete_ind_from_array(new_obses, live_env_index, axis=1) + new_obs = new_obses[step % 2] + actions = delete_ind_from_array(actions, live_env_index) + rewards = delete_ind_from_array(rewards, live_env_index) + terminals = delete_ind_from_array(terminals, live_env_index) + self._agent.delete_one(live_env_index) + del live_envs[live_env_index] + live_env_index -= 1 # Go back one to make up for deleting a value. + else: + episode_end[live_env_index] = 1 + new_obs[live_env_index] = self._initialize_episode([envs[env_id]]) + self._agent.reset_one(env_id=live_env_index) + elif done: + self._agent.reset_one(env_id=live_env_index) + + live_env_index += 1 + + if self._clip_rewards: + # Perform reward clipping. + rewards = np.clip(rewards, -1, 1) + + self._agent.log_transition(new_obs, actions, rewards, terminals, + episode_end) + + if ( + not live_envs + or (max_steps is not None and total_steps > max_steps) + or (episodes is not None and total_episodes > episodes) + ): + break + + state = (new_obses, rewards, terminals, episode_end, cum_rewards, + cum_lengths) + return cum_lengths, cum_rewards, state, envs + + def _run_train_phase(self, statistics): + """Run training phase. + + Args: + statistics: `IterationStatistics` object which records the experimental + results. Note - This object is modified by this method. + + Returns: + num_episodes: int, The number of episodes run in this phase. + average_reward: float, The average reward generated in this phase. + average_steps_per_second: float, The average number of steps per + second. + """ + # Perform the training phase, during which the agent learns. 
+ self._agent.eval_mode = False + self._agent.restore_train_state() + start_time = time.time() + ( + number_steps, + sum_returns, + num_episodes, + self.train_state, + self.train_envs, + ) = self._run_one_phase( + self.train_envs, + self._training_steps, + max_episodes=None, + statistics=statistics, + run_mode_str='train', + needs_reset=self.train_state is None, + resume_state=self.train_state, + ) + average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0 + statistics.append({'train_average_return': average_return}) + time_delta = time.time() - start_time + average_steps_per_second = number_steps / time_delta + statistics.append( + {'train_average_steps_per_second': average_steps_per_second} + ) + logging.info( + 'Average undiscounted return per training episode: %.2f', average_return + ) + logging.info( + 'Average training steps per second: %.2f', average_steps_per_second + ) + + if self.log_normalized_scores: + normalized_return = normalization_utils.normalize_score( + average_return, self.game_name + ) + statistics.append({'train_average_normalized_score': normalized_return}) + logging.info( + 'Average normalized return per training episode: %.2f', + normalized_return, + ) + else: + normalized_return = None + + self._agent.cache_train_state() + return ( + num_episodes, + average_return, + average_steps_per_second, + normalized_return, + ) + + def _run_eval_phase(self, statistics): + """Run evaluation phase. + + Args: + statistics: `IterationStatistics` object which records the experimental + results. Note - This object is modified by this method. + + Returns: + num_episodes: int, The number of episodes run in this phase. + average_reward: float, The average reward generated in this phase. + """ + # Perform the evaluation phase -- no learning. 
+ self._agent.eval_mode = True + eval_envs = [ + self.create_environment_fn() for i in range(self.num_eval_envs) + ] + _, sum_returns, num_episodes, _, _ = self._run_one_phase( + eval_envs, + steps=None, + max_episodes=self._num_eval_episodes, + statistics=statistics, + needs_reset=True, + resume_state=None, + one_to_one=self.eval_one_to_one, + run_mode_str='eval', + ) + average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0 + logging.info( + 'Average undiscounted return per evaluation episode: %.2f', + average_return, + ) + statistics.append({'eval_average_return': average_return}) + if self.log_normalized_scores: + normalized_return = normalization_utils.normalize_score( + average_return, self.game_name + ) + statistics.append({'train_average_normalized_score': normalized_return}) + logging.info( + 'Average normalized return per evaluation episode: %.2f', + normalized_return, + ) + else: + normalized_return = None + return num_episodes, average_return, normalized_return + + def _run_one_iteration(self, iteration): + """Runs one iteration of agent/environment interaction.""" + statistics = iteration_statistics.IterationStatistics() + logging.info('Starting iteration %d', iteration) + ( + num_episodes_train, + average_reward_train, + average_steps_per_second, + norm_score_train, + ) = self._run_train_phase(statistics) + num_episodes_eval, average_reward_eval, norm_score_eval = ( + self._run_eval_phase(statistics) + ) + self._save_tensorboard_summaries( + iteration, + num_episodes_train, + average_reward_train, + norm_score_train, + num_episodes_eval, + average_reward_eval, + norm_score_eval, + average_steps_per_second, + ) + return statistics.data_lists + + def _maybe_save_single_summary(self, + iteration, + ep_return, + length, + save_if_eval=False): + prefix = 'Train/' if not self._agent.eval_mode else 'Eval/' + if not self._agent.eval_mode or save_if_eval: + with self._summary_writer.as_default(): + tf.summary.scalar(prefix + 'EpisodeLength', length, step=iteration) + tf.summary.scalar(prefix + 'EpisodeReturn', ep_return, step=iteration) + if self.log_normalized_scores: + normalized_score = normalization_utils.normalize_score( + ep_return, self.game_name + ) + tf.summary.scalar( + prefix + 'EpisodeNormalizedScore', + normalized_score, + step=iteration, + ) + + def _save_tensorboard_summaries( + self, + iteration, + num_episodes_train, + average_reward_train, + norm_score_train, + num_episodes_eval, + average_reward_eval, + norm_score_eval, + average_steps_per_second, + ): + """Save statistics as tensorboard summaries. + + Args: + iteration: int, The current iteration number. + num_episodes_train: int, number of training episodes run. + average_reward_train: float, The average training reward. + norm_score_train: float, average training normalized score. + num_episodes_eval: int, number of evaluation episodes run. + average_reward_eval: float, The average evaluation reward. + norm_score_eval: float, average eval normalized score. + average_steps_per_second: float, The average number of steps per second. 
+ """ + with self._summary_writer.as_default(): + tf.summary.scalar('Train/NumEpisodes', num_episodes_train, step=iteration) + tf.summary.scalar( + 'Train/AverageReturns', average_reward_train, step=iteration) + tf.summary.scalar( + 'Train/AverageStepsPerSecond', + average_steps_per_second, + step=iteration) + tf.summary.scalar('Eval/NumEpisodes', num_episodes_eval, step=iteration) + tf.summary.scalar( + 'Eval/AverageReturns', average_reward_eval, step=iteration) + if self.log_normalized_scores: + tf.summary.scalar( + 'Train/AverageNormalizedScore', norm_score_train, step=iteration + ) + tf.summary.scalar( + 'Eval/NormalizedScore', norm_score_eval, step=iteration + ) + + def run_experiment(self): + """Runs a full experiment, spread over multiple iterations.""" + logging.info('Beginning training...') + if self._num_iterations <= self._start_iteration: + logging.warning('num_iterations (%d) < start_iteration(%d)', + self._num_iterations, self._start_iteration) + return + + for iteration in range(self._start_iteration, self._num_iterations): + statistics = self._run_one_iteration(iteration) + self._log_experiment(iteration, statistics) + self._checkpoint_experiment(iteration) + self._summary_writer.flush() + + +@gin.configurable +class LoggedDataEfficientAtariRunner(DataEfficientAtariRunner): + """Runner for loading/saving replay data.""" + + def __init__(self, + base_dir, + create_agent_fn, + load_replay_dir=None, + save_replay=False): + super().__init__(base_dir, create_agent_fn) + self._load_replay_dir = load_replay_dir + self._save_replay = save_replay + logging.info('Load fixed replay from directory: %s', load_replay_dir) + logging.info('Save replay: %s', save_replay) + + def run_experiment(self): + """Runs a full experiment, spread over multiple iterations.""" + if self._load_replay_dir is not None: + self._agent.load_fixed_replay(self._load_replay_dir) + super().run_experiment() + if self._save_replay: + save_replay_dir = os.path.join(self._base_dir, 'replay_logs') + self._agent.save_replay(save_replay_dir) + + +def delete_ind_from_array(array, ind, axis=0): + start = tuple(([slice(None)] * axis) + [slice(0, ind)]) + end = tuple(([slice(None)] * axis) + [slice(ind + 1, array.shape[axis] + 1)]) + tensor = np.concatenate([array[start], array[end]], axis) + return tensor diff --git a/dopamine/labs/atari_100k/configs/SPR.gin b/dopamine/labs/atari_100k/configs/SPR.gin new file mode 100644 index 00000000..7c0737b0 --- /dev/null +++ b/dopamine/labs/atari_100k/configs/SPR.gin @@ -0,0 +1,64 @@ + +# Self-Predictive Representations (Schwarzer et al, 2021) +# Major changes from standard Rainbow other than the SPR loss itself include: +# * No separate target network (update period 1), +# * 10-step returns (instead of 3) +# * Data augmentation (as in DrQ, Kostrikov et al 2020) +# * Replay every step, and two train updates per step +# * total replay ratio increased by factor of 8 +# These are roughly a hybrid of DrQ and Rainbow's hyperparameters. 
+ +import dopamine.jax.agents.dqn.dqn_agent +import dopamine.jax.networks +import dopamine.discrete_domains.checkpointer +import dopamine.discrete_domains.gym_lib +import dopamine.discrete_domains.run_experiment +import dopamine.replay_memory.prioritized_replay_buffer +import dopamine.labs.atari_100k.replay_memory.subsequence_replay_buffer +import dopamine.labs.atari_100k.spr_networks +import dopamine.labs.atari_100k.spr_agent +import dopamine.labs.atari_100k.atari_100k_rainbow_agent + +# Parameters specific to DrQ are higlighted by comments +JaxDQNAgent.gamma = 0.99 +JaxDQNAgent.update_horizon = 10 # DrQ (instead of 3) +JaxDQNAgent.min_replay_history = 2000 # DrQ (instead of 20000) +JaxDQNAgent.update_period = 1 # DrQ (rather than 4) +JaxDQNAgent.target_update_period = 1 # DrQ (rather than 8000) +JaxDQNAgent.epsilon_train = 0.00 +JaxDQNAgent.epsilon_eval = 0.001 +JaxDQNAgent.epsilon_decay_period = 2001 # DrQ +JaxDQNAgent.optimizer = 'adam' + +JaxFullRainbowAgent.noisy = True +JaxFullRainbowAgent.dueling = True +JaxFullRainbowAgent.double_dqn = True +JaxFullRainbowAgent.distributional = True +JaxFullRainbowAgent.num_atoms = 51 +JaxFullRainbowAgent.num_updates_per_train_step = 2 +JaxFullRainbowAgent.replay_scheme = 'prioritized' +JaxFullRainbowAgent.network = @dopamine.labs.atari_100k.spr_networks.SPRNetwork +Atari100kRainbowAgent.data_augmentation = True +dopamine.labs.atari_100k.spr_agent.SPRAgent.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon +dopamine.labs.atari_100k.spr_agent.SPRAgent.spr_weight = 5 +dopamine.labs.atari_100k.spr_agent.SPRAgent.jumps = 5 + +# Note these parameters are from DER (van Hasselt et al, 2019) +create_optimizer.learning_rate = 0.0001 +create_optimizer.eps = 0.00015 + +atari_lib.create_atari_environment.game_name = 'Pong' +# Atari 100K benchmark doesn't use sticky actions. +atari_lib.create_atari_environment.sticky_actions = False +AtariPreprocessing.terminal_on_life_loss = True +Runner.num_iterations = 10 +Runner.training_steps = 10000 # agent steps +MaxEpisodeEvalRunner.num_eval_episodes = 100 # agent episodes +Runner.max_steps_per_episode = 27000 # agent steps + +dopamine.labs.atari_100k.replay_memory.subsequence_replay_buffer.PrioritizedJaxSubsequenceParallelEnvReplayBuffer.replay_capacity = 1000000 +dopamine.labs.atari_100k.replay_memory.subsequence_replay_buffer.PrioritizedJaxSubsequenceParallelEnvReplayBuffer.batch_size = 32 +dopamine.labs.atari_100k.replay_memory.subsequence_replay_buffer.JaxSubsequenceParallelEnvReplayBuffer.replay_capacity = 1000000 +dopamine.labs.atari_100k.replay_memory.subsequence_replay_buffer.JaxSubsequenceParallelEnvReplayBuffer.batch_size = 32 + +Checkpointer.keep_every = 1 diff --git a/dopamine/labs/atari_100k/normalization_utils.py b/dopamine/labs/atari_100k/normalization_utils.py new file mode 100644 index 00000000..42f352ea --- /dev/null +++ b/dopamine/labs/atari_100k/normalization_utils.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
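One way to exercise the new SPR.gin from Python is gin's standard parse-and-bind entry point, shown here as a hedged sketch; the override binding is only an example, and the config's gin import lines require the referenced Dopamine modules to be importable.

import gin

gin.parse_config_files_and_bindings(
    ['dopamine/labs/atari_100k/configs/SPR.gin'],
    bindings=["atari_lib.create_atari_environment.game_name = 'Breakout'"])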
+"""Human and random Atari scores, and a function to normalization by them.""" + +ATARI_HUMAN_SCORES = { + 'alien': 7127.7, + 'amidar': 1719.5, + 'assault': 742.0, + 'asterix': 8503.3, + 'asteroids': 47388.7, + 'atlantis': 29028.1, + 'bankheist': 753.1, + 'battlezone': 37187.5, + 'beamrider': 16926.5, + 'berzerk': 2630.4, + 'bowling': 160.7, + 'boxing': 12.1, + 'breakout': 30.5, + 'centipede': 12017.0, + 'choppercommand': 7387.8, + 'crazyclimber': 35829.4, + 'demonattack': 1971.0, + 'doubledunk': -16.4, + 'enduro': 860.5, + 'fishingderby': -38.7, + 'freeway': 29.6, + 'frostbite': 4334.7, + 'gopher': 2412.5, + 'gravitar': 3351.4, + 'hero': 30826.4, + 'icehockey': 0.9, + 'jamesbond': 302.8, + 'kangaroo': 3035.0, + 'krull': 2665.5, + 'kungfumaster': 22736.3, + 'montezumarevenge': 4753.3, + 'mspacman': 6951.6, + 'namethisgame': 8049.0, + 'phoenix': 7242.6, + 'pitfall': 6463.7, + 'pong': 14.6, + 'privateeye': 69571.3, + 'qbert': 13455.0, + 'riverraid': 17118.0, + 'roadrunner': 7845.0, + 'robotank': 11.9, + 'seaquest': 42054.7, + 'skiing': -4336.9, + 'solaris': 12326.7, + 'spaceinvaders': 1668.7, + 'stargunner': 10250.0, + 'tennis': -8.3, + 'timepilot': 5229.2, + 'tutankham': 167.6, + 'upndown': 11693.2, + 'venture': 1187.5, + 'videopinball': 17667.9, + 'wizardofwor': 4756.5, + 'yarsrevenge': 54576.9, + 'zaxxon': 9173.3, +} + +ATARI_RANDOM_SCORES = { + 'alien': 227.8, + 'amidar': 5.8, + 'assault': 222.4, + 'asterix': 210.0, + 'asteroids': 719.1, + 'atlantis': 12850.0, + 'bankheist': 14.2, + 'battlezone': 2360.0, + 'beamrider': 363.9, + 'berzerk': 123.7, + 'bowling': 23.1, + 'boxing': 0.1, + 'breakout': 1.7, + 'centipede': 2090.9, + 'choppercommand': 811.0, + 'crazyclimber': 10780.5, + 'defender': 2874.5, + 'demonattack': 152.1, + 'doubledunk': -18.6, + 'enduro': 0.0, + 'fishingderby': -91.7, + 'freeway': 0.0, + 'frostbite': 65.2, + 'gopher': 257.6, + 'gravitar': 173.0, + 'hero': 1027.0, + 'icehockey': -11.2, + 'jamesbond': 29.0, + 'kangaroo': 52.0, + 'krull': 1598.0, + 'kungfumaster': 258.5, + 'montezumarevenge': 0.0, + 'mspacman': 307.3, + 'namethisgame': 2292.3, + 'phoenix': 761.4, + 'pitfall': -229.4, + 'pong': -20.7, + 'privateeye': 24.9, + 'qbert': 163.9, + 'riverraid': 1338.5, + 'roadrunner': 11.5, + 'robotank': 2.2, + 'seaquest': 68.4, + 'skiing': -17098.1, + 'solaris': 1236.3, + 'spaceinvaders': 148.0, + 'stargunner': 664.0, + 'surround': -10.0, + 'tennis': -23.8, + 'timepilot': 3568.0, + 'tutankham': 11.4, + 'upndown': 533.4, + 'venture': 0.0, + 'videopinball': 0.0, + 'wizardofwor': 563.5, + 'yarsrevenge': 3092.9, + 'zaxxon': 32.5, +} + + +def normalize_score(ret, game): + return (ret - ATARI_RANDOM_SCORES[game]) / ( + ATARI_HUMAN_SCORES[game] - ATARI_RANDOM_SCORES[game] + ) diff --git a/dopamine/labs/atari_100k/replay_memory/deterministic_sum_tree.py b/dopamine/labs/atari_100k/replay_memory/deterministic_sum_tree.py new file mode 100755 index 00000000..82ed5045 --- /dev/null +++ b/dopamine/labs/atari_100k/replay_memory/deterministic_sum_tree.py @@ -0,0 +1,173 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""A sum tree data structure that uses JAX for controlling randomness.""" + +import functools + +from dopamine.replay_memory import sum_tree +import jax +from jax import numpy as jnp +import numpy as np + + +@functools.partial(jax.jit) +def step(i, args): # pylint: disable=unused-argument + query_value, index, nodes = args + left_child = index * 2 + 1 + left_sum = nodes[left_child] + index = jax.lax.cond(query_value < left_sum, lambda x: x, lambda x: x + 1, + left_child) + query_value = jax.lax.cond(query_value < left_sum, lambda x: x, + lambda x: x - left_sum, query_value) + return query_value, index, nodes + + +@functools.partial(jax.jit) +@functools.partial(jax.vmap, in_axes=(None, None, 0, None, None)) +def parallel_stratified_sample(rng, nodes, i, n, depth): + """Sample a batch of indices in parallel with vmap. + + Args: + rng: a JAX prng key + nodes: the sum tree's storage array + i: the ID of the current search (index in batch) + n: batch size + depth: the depth of the sum tree + + Returns: + indices: a vector of indices, of shape (n,). + """ + rng = jax.random.fold_in(rng, i) + total_priority = nodes[0] + upper_bound = (i + 1) / n + lower_bound = i / n + query = jax.random.uniform(rng, minval=lower_bound, maxval=upper_bound) + _, index, _ = jax.lax.fori_loop(0, depth, step, + (query * total_priority, 0, nodes)) + return index + + +class DeterministicSumTree(sum_tree.SumTree): + """A sum tree data structure for storing replay priorities. + + In contrast to the original implementation, this uses JAX for handling + randomness, which allows us to reproduce the same results when using the + same seed. + """ + + def __init__(self, capacity): + """Creates the sum tree data structure for the given replay capacity. + + Args: + capacity: int, the maximum number of elements that can be stored in + this data structure. + + Raises: + ValueError: If requested capacity is not positive. + """ + assert isinstance(capacity, int) + if capacity <= 0: + raise ValueError( + 'Sum tree capacity should be positive. Got: {}'.format(capacity)) + + self.depth = int(np.ceil(np.log2(capacity))) + self.low_idx = (2**self.depth) - 1 # pri_idx + low_idx -> tree_idx + self.high_idx = capacity + self.low_idx + self.nodes = np.zeros(2**(self.depth + 1) - 1) # Double precision. + self.capacity = capacity + + self.highest_set = 0 + + self.max_recorded_priority = 1.0 + + def _total_priority(self): + """Returns the sum of all priorities stored in this sum tree. + + Returns: + float, sum of priorities stored in this sum tree. 
+ """ + return self.nodes[0] + + def sample(self, rng, query_value=None): + """Samples an element from the sum tree.""" + rng = jax.device_put(rng, jax.devices('cpu')[0]) + nodes = jax.device_put(jnp.asarray(self.nodes), jax.devices('cpu')[0]) + query_value = ( + jax.random.uniform(rng) if query_value is None else query_value) + query_value *= self._total_priority() + + _, index, _ = jax.lax.fori_loop(0, self.depth, step, + (query_value, 0, nodes)) + + return np.minimum(index - self.low_idx, self.highest_set) + + def stratified_sample(self, batch_size, rng): + """Performs stratified sampling using the sum tree.""" + if self._total_priority() == 0.0: + raise ValueError('Cannot sample from an empty sum tree.') + + rng = jax.device_put(rng, jax.devices('cpu')[0]) + nodes = jax.device_put(jnp.asarray(self.nodes), jax.devices('cpu')[0]) + indices = parallel_stratified_sample( + rng, nodes, np.arange(batch_size), batch_size, self.depth + ) + return np.minimum(indices - self.low_idx, self.highest_set) + + def get(self, node_index): + """Returns the value of the leaf node corresponding to the index. + + Args: + node_index: The index of the leaf node. + + Returns: + The value of the leaf node. + """ + return self.nodes[node_index + self.low_idx] + + def reset_priorities(self): + for i in range(self.highest_set): + self.set(i, self.max_recorded_priority) + + def set(self, node_index, value): + """Sets the value of a leaf node and updates internal nodes accordingly. + + This operation takes O(log(capacity)). + Args: + node_index: int, the index of the leaf node to be updated. + value: float, the value which we assign to the node. This value must + be nonnegative. Setting value = 0 will cause the element to never + be sampled. + + Raises: + ValueError: If the given value is negative. + """ + if value < 0.0: + raise ValueError( + 'Sum tree values should be nonnegative. Got {}'.format(value)) + self.highest_set = max(node_index, self.highest_set) + node_index = node_index + self.low_idx + self.max_recorded_priority = max(value, self.max_recorded_priority) + + delta_value = value - self.nodes[node_index] + + # Now traverse back the tree, adjusting all sums along the way. + for _ in reversed(range(self.depth)): + # Note: Adding a delta leads to some tolerable numerical inaccuracies. + self.nodes[node_index] += delta_value + node_index = (node_index - 1) // 2 + + self.nodes[node_index] += delta_value + assert node_index == 0, ('Sum tree traversal failed, final node index ' + 'is not 0.') diff --git a/dopamine/labs/atari_100k/replay_memory/subsequence_replay_buffer.py b/dopamine/labs/atari_100k/replay_memory/subsequence_replay_buffer.py new file mode 100755 index 00000000..de14210e --- /dev/null +++ b/dopamine/labs/atari_100k/replay_memory/subsequence_replay_buffer.py @@ -0,0 +1,997 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""A subsequence replay buffer that supports recurrent and standard algorithms.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gzip +import math +import os +import pickle + +from absl import logging +from dopamine.labs.atari_100k.replay_memory import deterministic_sum_tree as sum_tree +import gin +import jax +from jax import numpy as jnp +import numpy as np +import tensorflow as tf + + +# Defines a type describing part of the tuple returned by the replay +# memory. Each element of the tuple is a tensor of shape [batch, ...] where +# ... is defined the 'shape' field of ReplayElement. The tensor type is +# given by the 'type' field. The 'name' field is for convenience and ease of +# debugging. +ReplayElement = ( + collections.namedtuple('shape_type', ['name', 'shape', 'type'])) + +# A prefix that can not collide with variable names for checkpoint files. +STORE_FILENAME_PREFIX = '$store$_' + +# This constant determines how many iterations a checkpoint is kept for. +CHECKPOINT_DURATION = 4 + + +def modulo_range(start, length, modulo): + for i in range(length): + yield (start + i) % modulo + + +def invalid_range(cursor, replay_capacity, stack_size, update_horizon): + """Returns a array with the indices of cursor-related invalid transitions. + + There are update_horizon + stack_size invalid indices: + - The update_horizon indices before the cursor, because we do not have a + valid N-step transition (including the next state). + - The stack_size indices on or immediately after the cursor. + If N = update_horizon, K = stack_size, and the cursor is at c, invalid + indices are: + c - N, c - N + 1, ..., c, c + 1, ..., c + K - 1. + It handles special cases in a circular buffer in the beginning and the end. + Args: + cursor: int, the position of the cursor. + replay_capacity: int, the size of the replay memory. + stack_size: int, the size of the stacks returned by the replay memory. + update_horizon: int, the agent's update horizon. + + Returns: + np.array of size stack_size with the invalid indices. + """ + assert cursor < replay_capacity + return np.array([(cursor - update_horizon + i) % replay_capacity + for i in range(stack_size + update_horizon)]) + + +@gin.configurable +class JaxSubsequenceParallelEnvReplayBuffer(object): + """A simple out-of-graph Replay Buffer. + + Stores transitions, state, action, reward, next_state, terminal (and any + extra contents specified) in a circular buffer and provides a uniform + transition sampling function. + When the states consist of stacks of observations storing the states is + inefficient. This class writes observations and constructs the stacked + states at sample time. + This class supports multiple parallel environments and returns + subsequences by default. + Attributes: + add_count: int, counter of how many transitions have been added (including + the blank ones at the beginning of an episode). + invalid_range: np.array, an array with the indices of cursor-related invalid + transitions + total_steps: int, total number of transitions added across all environments. 
+ """ + + def __init__( + self, + observation_shape, + stack_size, + replay_capacity, + batch_size, + subseq_len, + n_envs=1, + update_horizon=1, + gamma=0.99, + max_sample_attempts=1000, + use_next_state=True, + extra_storage_types=None, + observation_dtype=np.uint8, + terminal_dtype=np.uint8, + action_shape=(), + action_dtype=np.int32, + reward_shape=(), + reward_dtype=np.float32, + ): + """Initializes OutOfGraphReplayBuffer. + + Args: + observation_shape: tuple of ints. + stack_size: int, number of frames to use in state stack. + replay_capacity: int, number of transitions to keep in memory. + batch_size: int. + subseq_len: int, length of subsequences to return. + n_envs: int, how many parallel environments will be writing data. + update_horizon: int, length of update ('n' in n-step update). + gamma: int, the discount factor. + max_sample_attempts: int, the maximum number of attempts allowed to get a + sample. + use_next_state: bool, whether to return separate "next_observation", + "next_reward" and "next_action" entries. Disable to reduce sampling time + in pure sequence modeling tasks. + extra_storage_types: list of ReplayElements defining the type of the extra + contents that will be stored and returned by sample_transition_batch. + observation_dtype: np.dtype, type of the observations. Defaults to + np.uint8 for Atari 2600. + terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for + Atari 2600. + action_shape: tuple of ints, the shape for the action vector. Empty tuple + means the action is a scalar. + action_dtype: np.dtype, type of elements in the action. + reward_shape: tuple of ints, the shape of the reward vector. Empty tuple + means the reward is a scalar. + reward_dtype: np.dtype, type of elements in the reward. + + Raises: + ValueError: If replay_capacity is too small to hold at least one + transition. + """ + assert isinstance(observation_shape, tuple) + if replay_capacity < update_horizon + stack_size: + raise ValueError('There is not enough capacity to cover ' + 'update_horizon and stack_size.') + + logging.info('Creating a %s replay memory with the following parameters:', + self.__class__.__name__) + logging.info('\t observation_shape: %s', str(observation_shape)) + logging.info('\t observation_dtype: %s', str(observation_dtype)) + logging.info('\t terminal_dtype: %s', str(terminal_dtype)) + logging.info('\t stack_size: %d', stack_size) + logging.info('\t use_next_state: %d', use_next_state) + logging.info('\t replay_capacity: %d', replay_capacity) + logging.info('\t batch_size: %d', batch_size) + logging.info('\t update_horizon: %d', update_horizon) + logging.info('\t gamma: %f', gamma) + + self._action_shape = action_shape + self._action_dtype = action_dtype + self._reward_shape = reward_shape + self._reward_dtype = reward_dtype + self._observation_shape = observation_shape + self._stack_size = stack_size + self._state_shape = self._observation_shape + (self._stack_size,) + self._batch_size = batch_size + self._update_horizon = update_horizon + self._gamma = gamma + self._observation_dtype = observation_dtype + self._terminal_dtype = terminal_dtype + self._max_sample_attempts = max_sample_attempts + self._subseq_len = subseq_len + self._use_next_state = use_next_state + + self._n_envs = n_envs + self._replay_length = int(replay_capacity // self._n_envs) + + # Gotta round this down, since the matrix is rectangular. 
+ self._replay_capacity = self._replay_length * self._n_envs + + self.total_steps = 0 + + if extra_storage_types: + self._extra_storage_types = extra_storage_types + else: + self._extra_storage_types = [] + self._create_storage() + self.add_count = np.array(0) + self.invalid_range = np.zeros((self._stack_size)) + # When the horizon is > 1, we compute the sum of discounted rewards as a dot + # product using the precomputed vector . + self._cumulative_discount_vector = np.array( + [math.pow(self._gamma, n) for n in range(update_horizon + 1)], + dtype=np.float32) + self._next_experience_is_episode_start = True + self._episode_end_indices = set() + + def _create_storage(self): + """Creates the numpy arrays used to store transitions.""" + self._store = {} + for storage_element in self.get_storage_signature(): + array_shape = [self._replay_length, self._n_envs] + list( + storage_element.shape + ) + self._store[storage_element.name] = np.empty( + array_shape, dtype=storage_element.type) + + def get_add_args_signature(self): + """The signature of the add function. + + Note - Derived classes may return a different signature. + Returns: + list of ReplayElements defining the type of the argument signature + needed by the add function. + """ + return self.get_storage_signature() + + def get_storage_signature(self): + """Returns a default list of elements to be stored in this replay memory. + + Note - Derived classes may return a different signature. + Returns: + list of ReplayElements defining the type of the contents stored. + """ + storage_elements = [ + ReplayElement('observation', self._observation_shape, + self._observation_dtype), + ReplayElement('action', self._action_shape, self._action_dtype), + ReplayElement('reward', self._reward_shape, self._reward_dtype), + ReplayElement('terminal', (), self._terminal_dtype) + ] + + for extra_replay_element in self._extra_storage_types: + storage_elements.append(extra_replay_element) + return storage_elements + + def _add_zero_transition(self): + """Adds a padding transition filled with zeros (Used in episode beginnings). + """ + zero_transition = [] + for element_type in self.get_add_args_signature(): + zero_transition.append( + np.zeros(element_type.shape, dtype=element_type.type)) + self._episode_end_indices.discard(self.cursor()) # If present + self._add(*zero_transition) + + def add(self, + observation, + action, + reward, + terminal, + *args, + priority=None, + episode_end=False): + """Adds a transition to the replay memory. + + This function checks the types and handles the padding at the beginning + of an episode. Then it calls the _add function. + Since the next_observation in the transition will be the observation + added next there is no need to pass it. + If the replay memory is at capacity the oldest transition will be + discarded. + + Args: + observation: np.array with shape observation_shape. + action: int, the action in the transition. + reward: float, the reward received in the transition. + terminal: np.dtype, acts as a boolean indicating whether the transition + was terminal (1) or not (0). + *args: extra contents with shapes and dtypes according to + extra_storage_types. + priority: float, unused in the circular replay buffer, but may be used in + child classes like PrioritizedReplayBuffer. + episode_end: bool, whether this experience is the last experience in the + episode. This is useful for tasks that terminate due to time-out, but do + not end on a terminal state. 
Overloading 'terminal' may not be + sufficient in this case, since 'terminal' is passed to the agent for + training. 'episode_end' allows the replay buffer to determine episode + boundaries without passing that information to the agent. + """ + if priority is not None: + args = args + (priority,) + + self.total_steps += self._n_envs + self._check_add_types(observation, action, reward, terminal, *args) + + resets = episode_end + terminal + for i in range(resets.shape[0]): + if resets[i]: + self._episode_end_indices.add((self.cursor(), i)) + else: + self._episode_end_indices.discard((self.cursor(), i)) # If present + + self._add(observation, action, reward, terminal, *args) + + def _add(self, *args): + """Internal add method to add to the storage arrays. + + Args: + *args: All the elements in a transition. + """ + self._check_args_length(*args) + transition = { + e.name: args[idx] for idx, e in enumerate(self.get_add_args_signature()) + } + self._add_transition(transition) + + def _add_transition(self, transition): + """Internal add method to add transition dictionary to storage arrays. + + Args: + transition: The dictionary of names and values of the transition to add + to the storage. Each tensor should have leading dim equal to the + number of environments used by the buffer. + """ + cursor = self.cursor() + for arg_name in transition: + self._store[arg_name][cursor] = transition[arg_name] + + self.add_count += 1 + self.invalid_range = invalid_range(self.cursor(), self._replay_length, + self._stack_size, self._update_horizon) + + def _check_args_length(self, *args): + """Check if args passed to the add method have the same length as storage. + + Args: + *args: Args for elements used in storage. + + Raises: + ValueError: If args have wrong length. + """ + if len(args) != len(self.get_add_args_signature()): + raise ValueError('Add expects {} elements, received {}'.format( + len(self.get_add_args_signature()), len(args))) + + def _check_add_types(self, *args): + """Checks if args passed to the add method match those of the storage. + + Args: + *args: Args whose types need to be validated. + + Raises: + ValueError: If args have wrong shape or dtype. + """ + self._check_args_length(*args) + for i, (arg_element, store_element) in enumerate( + zip(args, self.get_add_args_signature())): + if hasattr(arg_element, 'shape'): + arg_shape = arg_element.shape + elif isinstance(arg_element, tuple) or isinstance(arg_element, list): + # TODO(b/80536437). This is not efficient when arg_element is a list. + arg_shape = np.array(arg_element).shape + else: + # Assume it is scalar. 
+ arg_shape = tuple() + store_element_shape = tuple(store_element.shape) + assert arg_shape[0] == self._n_envs + arg_shape = arg_shape[1:] + if arg_shape != store_element_shape: + raise ValueError('arg {} has shape {}, expected {}'.format( + i, arg_shape, store_element_shape)) + + def is_empty(self): + """Is the Replay Buffer empty?""" + return self.add_count == 0 + + def is_full(self): + """Is the Replay Buffer full?""" + return self.add_count >= self._replay_length + + def ravel_indices(self, indices_t, indices_b): + return np.ravel_multi_index( + (indices_t, indices_b), (self._replay_length, self._n_envs), mode='wrap' + ) + + def unravel_indices(self, indices): + return np.unravel_index(indices, (self._replay_length, self._n_envs)) + + def get_from_store(self, element_name, indices_t, indices_b): + array = self._store[element_name] + return array[indices_t, indices_b] + + def cursor(self): + """Index to the location where the next transition will be written.""" + return self.add_count % self._replay_length + + def parallel_get_stack(self, element_name, indices_t, indices_b, first_valid): + indices_t = np.arange(-self._stack_size + 1, 1)[:, + None] + indices_t[None, :] + indices_b = indices_b[None, :].repeat(self._stack_size, axis=0) + mask = indices_t >= first_valid + result = self.get_from_store(element_name, indices_t % self._replay_length, + indices_b) + mask = mask.reshape(*mask.shape, *([1] * (len(result.shape) - 2))) + result = result * mask + result = np.moveaxis(result, 0, -1) + return result + + def get_terminal_stack(self, index_t, index_b): + return self.parallel_get_stack('terminal', index_t, index_b, 0) + + def is_valid_transition(self, index_t, index_b): + """Checks if the index contains a valid transition. + + Checks for collisions with the end of episodes and the current position + of the cursor. + Args: + index_t: int, index in the time dimension of the state. + index_b: int, index in the environment dimension of the state. + + Returns: + Is the index valid: Boolean. + Start of the current episode (if within our stack size): Integer. + """ + # Check the index is in the valid range + if index_t < 0 or index_t >= self._replay_length: + return False, 0 + if not self.is_full(): + # The indices and next_indices must be smaller than the cursor. + if index_t >= self.cursor() - self._update_horizon - self._subseq_len: + return False, 0 + # The first few indices contain the padding states of the first episode. + if index_t < self._stack_size - 1: + return False, 0 + + # Skip transitions that straddle the cursor. + if index_t[0] in set(self.invalid_range): + return False, 0 + + # If there are terminal flags in any other frame other than the last one + # the stack is not valid, so don't sample it. + terminals = self.get_terminal_stack(index_t, index_b)[0, :-1] + if terminals.any(): + ep_start = index_t - self._stack_size + terminals.argmax() + 2 + else: + ep_start = 0 + + # If the episode ends before the update horizon, without a terminal signal, + # it is invalid. + for i in modulo_range(index_t, self._update_horizon, self._replay_length): + if (i.item(), index_b.item( + )) in self._episode_end_indices and not self._store['terminal'][i, + index_b]: + return False, 0 + + return True, ep_start + + def _create_batch_arrays(self, batch_size): + """Create a tuple of arrays with the type of get_transition_elements. + + When using the WrappedReplayBuffer with staging enabled it is important to + create new arrays every sample because StaginArea keeps a pointer to the + returned arrays. 
+ Args: + batch_size: (int) number of transitions returned. If None the default + batch_size will be used. + + Returns: + Tuple of np.arrays with the shape and type of + get_transition_elements. + """ + transition_elements = self.get_transition_elements(batch_size) + batch_arrays = [] + for element in transition_elements: + batch_arrays.append(np.empty(element.shape, dtype=element.type)) + return tuple(batch_arrays) + + def num_elements(self): + if self.is_full(): + return self._replay_capacity + else: + return self.cursor() * self._n_envs + + def sample_index_batch(self, batch_size): + """Returns a batch of valid indices sampled uniformly. + + Args: + batch_size: int, number of indices returned. + + Returns: + list of ints, a batch of valid indices sampled uniformly. + + Raises: + RuntimeError: If the batch was not constructed after maximum number + of tries. + """ + self._rng, rng = jax.random.split(self._rng) + if self.is_full(): + # add_count >= self._replay_capacity > self._stack_size + min_id = self.cursor() - self._replay_length + self._stack_size - 1 + max_id = self.cursor() - self._update_horizon - self._subseq_len + else: + # add_count < self._replay_capacity + min_id = self._stack_size - 1 + max_id = self.cursor() - self._update_horizon - self._subseq_len + if max_id <= min_id: + raise RuntimeError('Cannot sample a batch with fewer than stack size ' + '({}) + update_horizon ({}) transitions.'.format( + self._stack_size, self._update_horizon)) + t_indices = jax.random.randint(rng, (batch_size,), min_id, + max_id) % self._replay_length + b_indices = jax.random.randint(rng, (batch_size,), 0, self._n_envs) + allowed_attempts = self._max_sample_attempts + t_indices = np.array(t_indices) + censor_before = np.zeros_like(t_indices) + for i in range(len(t_indices)): + is_valid, ep_start = self.is_valid_transition(t_indices[i:i + 1], + b_indices[i:i + 1]) + censor_before[i] = ep_start + if not is_valid: + if allowed_attempts == 0: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, i, batch_size)) + while not is_valid and allowed_attempts > 0: + # If index i is not valid keep sampling others. Note that this + # is not stratified. + self._rng, rng = jax.random.split(self._rng) + t_index = jax.random.randint(rng, (1,), min_id, + max_id) % self._replay_length + b_index = jax.random.randint(rng, (1,), 0, self._n_envs) + allowed_attempts -= 1 + t_indices[i] = t_index + b_indices[i] = b_index + is_valid, first_valid = self.is_valid_transition( + t_indices[i:i + 1], b_indices[i:i + 1]) + censor_before[i] = first_valid + return t_indices, b_indices, censor_before + + def restore_leading_dims(self, batch_size, subseq_len, tensor): + return tensor.reshape(batch_size, subseq_len, *tensor.shape[1:]) + + def sample(self, *args, **kwargs): + return self.sample_transition_batch(*args, **kwargs) + + def sample_transition_batch( + self, + rng=None, + batch_size=None, + indices=None, + subseq_len=None, + update_horizon=None, + gamma=None, + ): + """Returns a batch of transitions (including any extra contents). + + If get_transition_elements has been overridden and defines elements not + stored in self._store, an empty array will be returned and it will be + left to the child class to fill it. For example, for the child class + OutOfGraphPrioritizedReplayBuffer, the contents of the + sampling_probabilities are stored separately in a sum tree. 
+    When the transition is terminal next_state_batch has undefined contents.
+    NOTE: This transition contains the indices of the sampled elements. These
+    are only valid during the call to sample_transition_batch, i.e. they may
+    be used by subclasses of this replay buffer but may point to different
+    data as soon as sampling is done.
+
+    Args:
+      rng: Jax PRNG key, if overriding the default buffer state.
+      batch_size: int, number of transitions returned. If None, the default
+        batch_size will be used.
+      indices: None or list of ints, the indices of every transition in the
+        batch. If None, sample the indices uniformly.
+      subseq_len: The length of subsequence to sample. Can override the replay
+        buffer default.
+      update_horizon: Update horizon to use, if overriding the original setting.
+      gamma: Discount factor to use, if overriding the original setting.
+
+    Returns:
+      transition_batch: tuple of np.arrays with the shape and type as in
+        get_transition_elements().
+    Raises:
+      ValueError: If an element to be sampled is missing from the replay
+        buffer.
+    """
+    self._rng = rng if rng is not None else self._rng
+    if batch_size is None:
+      batch_size = self._batch_size
+    if subseq_len is None:
+      subseq_len = self._subseq_len
+    if update_horizon is None:
+      update_horizon = self._update_horizon
+    if indices is None:
+      t_indices, b_indices, censor_before = self.sample_index_batch(batch_size)
+    if gamma is None:
+      cumulative_discount_vector = self._cumulative_discount_vector
+    else:
+      cumulative_discount_vector = np.array(
+          [math.pow(gamma, n) for n in range(update_horizon + 1)],
+          dtype=np.float32,
+      )
+    assert len(t_indices) == batch_size
+    assert len(b_indices) == batch_size
+    transition_elements = self.get_transition_elements(batch_size)
+    state_indices = t_indices[:, None] + np.arange(subseq_len)[None, :]
+    state_indices = state_indices.reshape(
+        batch_size * subseq_len) % self._replay_length
+    b_indices = b_indices[:, None].repeat(
+        subseq_len, axis=1).reshape(batch_size * subseq_len)
+    censor_before = censor_before[:, None].repeat(
+        subseq_len, axis=1).reshape(batch_size * subseq_len)
+
+    # shape: horizon X batch_size*subseq_len
+    # Indices are offset back by one step so that rewards (stored one index
+    # later) line up with steps t, ..., t + update_horizon - 1.
+    trajectory_indices = (np.arange(-1, update_horizon - 1)[:, None] +
+                          state_indices[None, :]) % self._replay_length
+    trajectory_b_indices = b_indices[None,].repeat(update_horizon, axis=0)
+    trajectory_terminals = self._store['terminal'][trajectory_indices,
+                                                   trajectory_b_indices]
+    trajectory_terminals[0, :] = 0
+    is_terminal_transition = trajectory_terminals.any(0)
+    valid_mask = (1 - trajectory_terminals).cumprod(0)
+    trajectory_discount_vector = valid_mask * (
+        cumulative_discount_vector[:update_horizon, None]
+    )
+    trajectory_rewards = self._store['reward'][(trajectory_indices + 1) %
+                                               self._replay_length,
+                                               trajectory_b_indices]
+
+    returns = np.cumsum(trajectory_discount_vector * trajectory_rewards, axis=0)
+
+    update_horizons = jnp.ones(
+        batch_size * subseq_len, dtype=jnp.int32) * (
+            update_horizon - 1)
+    returns = returns[update_horizons, np.arange(batch_size * subseq_len)]
+
+    next_indices = (state_indices + update_horizons) % self._replay_length
+    outputs = []
+    for element in transition_elements:
+      name = element.name
+      if name == 'state':
+        output = self.parallel_get_stack(
+            'observation',
+            state_indices,
+            b_indices,
+            censor_before,
+        )
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'return':
+        # Compute the discounted sum of rewards over the trajectory.
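+        # Illustration (numbers not from the original code): with
+        # update_horizon=3, gamma=0.99 and rewards [1.0, 0.0, 2.0], the masked
+        # n-step return is 1.0 + 0.99 * 0.0 + 0.99**2 * 2.0 = 2.9602.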
+        output = returns
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'discount':
+        # Discount applied to the bootstrapped value: gamma ** update_horizon.
+        output = cumulative_discount_vector[update_horizons + 1]
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'next_state':
+        output = self.parallel_get_stack(
+            'observation',
+            next_indices,
+            b_indices,
+            censor_before,
+        )
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'same_trajectory':
+        output = self._store['terminal'][state_indices, b_indices]
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+        output[0, :] = 0
+        output = (1 - output).cumprod(1)
+      elif name in ('next_action', 'next_reward'):
+        # Slice off the 'next_' prefix to index the underlying store.
+        output = self._store[name[len('next_'):]][next_indices, b_indices]
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'terminal':
+        output = is_terminal_transition
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      elif name == 'indices':
+        output = self.ravel_indices(state_indices, b_indices).astype('int32')
+        output = self.restore_leading_dims(batch_size, subseq_len, output)[:, 0]
+      elif name in self._store.keys():
+        output = self._store[name][state_indices, b_indices]
+        output = self.restore_leading_dims(batch_size, subseq_len, output)
+      else:
+        continue
+      outputs.append(output)
+    return outputs
+
+  def get_transition_elements(self, batch_size=None, subseq_len=None):
+    """Returns a 'type signature' for sample_transition_batch.
+
+    Args:
+      batch_size: int, number of transitions returned. If None, the default
+        batch_size will be used.
+      subseq_len: int, length of subsequences to return.
+
+    Returns:
+      signature: A namedtuple describing the method's return type signature.
+    """
+    subseq_len = self._subseq_len if subseq_len is None else subseq_len
+    batch_size = self._batch_size if batch_size is None else batch_size
+
+    transition_elements = [
+        ReplayElement('state', (batch_size, subseq_len) + self._state_shape,
+                      self._observation_dtype),
+        ReplayElement('action', (batch_size, subseq_len) + self._action_shape,
+                      self._action_dtype),
+        ReplayElement('reward', (batch_size, subseq_len) + self._reward_shape,
+                      self._reward_dtype),
+        ReplayElement('return', (batch_size, subseq_len) + self._reward_shape,
+                      self._reward_dtype),
+        ReplayElement('discount', (), self._reward_dtype),
+    ]
+    if self._use_next_state:
+      transition_elements += [
+          ReplayElement('next_state',
+                        (batch_size, subseq_len) + self._state_shape,
+                        self._observation_dtype),
+          ReplayElement('next_action',
+                        (batch_size, subseq_len) + self._action_shape,
+                        self._action_dtype),
+          ReplayElement('next_reward',
+                        (batch_size, subseq_len) + self._reward_shape,
+                        self._reward_dtype),
+      ]
+    transition_elements += [
+        ReplayElement('terminal', (batch_size, subseq_len),
+                      self._terminal_dtype),
+        ReplayElement('same_trajectory', (batch_size, subseq_len),
+                      self._terminal_dtype),
+        ReplayElement('indices', (batch_size,), np.int32)
+    ]
+    for element in self._extra_storage_types:
+      transition_elements.append(
+          ReplayElement(element.name,
+                        (batch_size, subseq_len) + tuple(element.shape),
+                        element.type))
+    return transition_elements
+
+  def _generate_filename(self, checkpoint_dir, name, suffix):
+    return os.path.join(checkpoint_dir, '{}_ckpt.{}.gz'.format(name, suffix))
+
+  def _return_checkpointable_elements(self):
+    """Return the dict of elements of the class for checkpointing.
+ + Returns: + checkpointable_elements: dict containing all non private (starting + with _) members + all the arrays inside self._store. + """ + checkpointable_elements = {} + for member_name, member in self.__dict__.items(): + if member_name == '_store': + for array_name, array in self._store.items(): + checkpointable_elements[STORE_FILENAME_PREFIX + array_name] = array + elif not member_name.startswith('_'): + checkpointable_elements[member_name] = member + return checkpointable_elements + + def save(self, checkpoint_dir, iteration_number): + """Save the OutOfGraphReplayBuffer attributes into a file. + + This method will save all the replay buffer's state in a single file. + Args: + checkpoint_dir: str, the directory where numpy checkpoint files should be + saved. + iteration_number: int, iteration_number to use as a suffix in naming numpy + checkpoint files. + """ + if not tf.io.gfile.exists(checkpoint_dir): + return + + checkpointable_elements = self._return_checkpointable_elements() + + for attr in checkpointable_elements: + filename = self._generate_filename(checkpoint_dir, attr, iteration_number) + with tf.io.gfile.GFile(filename, 'wb') as f: + with gzip.GzipFile(fileobj=f) as outfile: + # Checkpoint the np arrays in self._store with np.save instead of + # pickling the dictionary is critical for file size and performance. + # STORE_FILENAME_PREFIX indicates that the variable is contained in + # self._store. + if attr.startswith(STORE_FILENAME_PREFIX): + array_name = attr[len(STORE_FILENAME_PREFIX):] + np.save(outfile, self._store[array_name], allow_pickle=False) + # Some numpy arrays might not be part of storage + elif isinstance(self.__dict__[attr], np.ndarray): + np.save(outfile, self.__dict__[attr], allow_pickle=False) + else: + pickle.dump(self.__dict__[attr], outfile) + + # After writing a checkpoint file, we garbage collect the checkpoint file + # that is four versions old. + stale_iteration_number = iteration_number - CHECKPOINT_DURATION + if stale_iteration_number >= 0: + stale_filename = self._generate_filename(checkpoint_dir, attr, + stale_iteration_number) + try: + tf.io.gfile.remove(stale_filename) + except tf.errors.NotFoundError: + pass + + def load(self, checkpoint_dir, suffix): + """Restores the object from bundle_dictionary and numpy checkpoints. + + Args: + checkpoint_dir: str, the directory where to read the numpy checkpointed + files from. + suffix: str, the suffix to use in numpy checkpoint files. + + Raises: + NotFoundError: If not all expected files are found in directory. + """ + save_elements = self._return_checkpointable_elements() + # We will first make sure we have all the necessary files available to avoid + # loading a partially-specified (i.e. corrupted) replay buffer. + for attr in save_elements: + filename = self._generate_filename(checkpoint_dir, attr, suffix) + if not tf.io.gfile.exists(filename): + raise tf.errors.NotFoundError(None, None, + 'Missing file: {}'.format(filename)) + # If we've reached this point then we have verified that all expected files + # are available. 
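+    # The restore below mirrors save(): arrays in self._store and any other
+    # numpy attributes are read back with np.load, while the remaining
+    # attributes are unpickled.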
+ for attr in save_elements: + filename = self._generate_filename(checkpoint_dir, attr, suffix) + with tf.io.gfile.GFile(filename, 'rb') as f: + with gzip.GzipFile(fileobj=f) as infile: + if attr.startswith(STORE_FILENAME_PREFIX): + array_name = attr[len(STORE_FILENAME_PREFIX):] + self._store[array_name] = np.load(infile, allow_pickle=False) + elif isinstance(self.__dict__[attr], np.ndarray): + self.__dict__[attr] = np.load(infile, allow_pickle=False) + else: + self.__dict__[attr] = pickle.load(infile) + + def reset_priorities(self): + pass + + +@gin.configurable +class PrioritizedJaxSubsequenceParallelEnvReplayBuffer( + JaxSubsequenceParallelEnvReplayBuffer): + """Deterministic version of prioritized replay buffer.""" + + def __init__(self, + observation_shape, + stack_size, + replay_capacity, + batch_size, + update_horizon=1, + subseq_len=0, + n_envs=1, + gamma=0.99, + max_sample_attempts=1000, + extra_storage_types=None, + observation_dtype=np.uint8, + terminal_dtype=np.uint8, + action_shape=(), + action_dtype=np.int32, + reward_shape=(), + reward_dtype=np.float32): + super().__init__( + observation_shape=observation_shape, + stack_size=stack_size, + replay_capacity=int(replay_capacity), + batch_size=batch_size, + update_horizon=update_horizon, + gamma=gamma, + max_sample_attempts=max_sample_attempts, + extra_storage_types=extra_storage_types, + observation_dtype=observation_dtype, + terminal_dtype=terminal_dtype, + subseq_len=subseq_len, + n_envs=n_envs, + action_shape=action_shape, + action_dtype=action_dtype, + reward_shape=reward_shape, + reward_dtype=reward_dtype) + + self.sum_tree = sum_tree.DeterministicSumTree(int(replay_capacity)) + + def get_add_args_signature(self): + """The signature of the add function.""" + parent_add_signature = super().get_add_args_signature() + add_signature = parent_add_signature + [ + ReplayElement('priority', (), np.float32) + ] + return add_signature + + def _add(self, *args): + """Internal add method to add to the underlying memory arrays.""" + self._check_args_length(*args) + + # Use Schaul et al.'s (2015) scheme of setting the priority of new elements + # to the maximum priority so far. + # Picks out 'priority' from arguments and adds it to the sum_tree. + transition = {} + for i, element in enumerate(self.get_add_args_signature()): + if element.name == 'priority': + priority = args[i] + else: + transition[element.name] = args[i] + + indices = np.ravel_multi_index( + (np.ones((1,), dtype='int32') * self.cursor(), np.arange(self._n_envs)), + (self._replay_length, self._n_envs), + ) + + for i in range(len(indices)): + self.sum_tree.set(indices[i], priority[i]) + super()._add_transition(transition) + + def sample_index_batch(self, batch_size): + """Returns a batch of valid indices sampled as in Schaul et al. (2015).""" + # Sample stratified indices. Some of them might be invalid. 
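+    # Stratified sampling splits the total priority mass into batch_size equal
+    # segments and draws one index from each segment, which gives lower
+    # variance than drawing batch_size indices independently.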
+ # start = time.time() + indices = self.sum_tree.stratified_sample(batch_size, self._rng) + indices = np.array(indices) + # print("Sampling from sum tree took {}".format(time.time() - start)) + allowed_attempts = self._max_sample_attempts + + t_indices, b_indices = self.unravel_indices(indices) # pylint: disable=unbalanced-tuple-unpacking + censor_before = np.zeros_like(t_indices) + for i in range(len(indices)): + is_valid, ep_start = self.is_valid_transition(t_indices[i:i + 1], + b_indices[i:i + 1]) + censor_before[i] = ep_start + if not is_valid: + if allowed_attempts == 0: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, i, batch_size)) + while (not is_valid) and allowed_attempts > 0: + # If index i is not valid keep sampling others. Note that this + # is not stratified. + self._rng, rng = jax.random.split(self._rng) + index = int(self.sum_tree.stratified_sample(1, rng=rng)) + t_index, b_index = self.unravel_indices(index) # pylint: disable=unbalanced-tuple-unpacking + + allowed_attempts -= 1 + t_indices[i] = t_index + b_indices[i] = b_index + is_valid, ep_start = self.is_valid_transition(t_indices[i:i + 1], + b_indices[i:i + 1]) + censor_before[i] = ep_start + return t_indices, b_indices, censor_before + + def sample_transition_batch( + self, + rng, + batch_size=None, + indices=None, + subseq_len=None, + update_horizon=None, + gamma=None, + ): + """Returns a batch of transitions with extra storage and the priorities.""" + transition = super().sample_transition_batch( + rng, + batch_size, + indices, + subseq_len=subseq_len, + update_horizon=update_horizon, + gamma=gamma, + ) + transition.append(self.get_priority(transition[-1])) + return transition + + def set_priority(self, indices, priorities): + """Sets the priority of the given elements according to Schaul et al.""" + assert indices.dtype == np.int32, ('Indices must be integers, ' + 'given: {}'.format(indices.dtype)) + priorities = np.asarray(priorities) + indices = np.asarray(indices) + for index, priority in zip(indices, priorities): + self.sum_tree.set(index, priority) + + def get_priority(self, indices): + """Fetches the priorities correspond to a batch of memory indices.""" + assert indices.shape, 'Indices must be an array.' + assert indices.dtype == np.int32, ('Indices must be int32s, ' + 'given: {}'.format(indices.dtype)) + priority_batch = self.sum_tree.get(indices) + return priority_batch + + def get_transition_elements(self, batch_size=None): + """Returns a 'type signature' for sample_transition_batch.""" + parent_transition_type = (super().get_transition_elements(batch_size)) + probablilities_type = [ + ReplayElement('sampling_probabilities', (batch_size,), np.float32) + ] + return parent_transition_type + probablilities_type + + def reset_priorities(self): + self.sum_tree.reset_priorities() diff --git a/dopamine/labs/atari_100k/spr_agent.py b/dopamine/labs/atari_100k/spr_agent.py new file mode 100644 index 00000000..1761a17c --- /dev/null +++ b/dopamine/labs/atari_100k/spr_agent.py @@ -0,0 +1,466 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An implementation of SPR in Jax. + +Includes the features included in the full Rainbow agent. Designed to work with +an optimized replay buffer that returns subsequences rather than individual +transitions. +Some details differ from the original implementation due to differences in +the underlying Rainbow implementations. In particular: +* Dueling networks in Dopamine separate at the final layer, not the penultimate + layer as in the original. +* Dopamine's prioritized experience replay does not decay its exponent + over time. +We find that these changes do not drastically impact the overall performance of +the algorithm, however. +Details on Rainbow are available in +"Rainbow: Combining Improvements in Deep Reinforcement Learning" by Hessel et +al. (2018). For details on SPR, see +"Data-Efficient Reinforcement Learning with Self-Predictive Representations" by +Schwarzer et al (2021). +""" + +import collections +import copy +import functools +import time + +from absl import logging +from dopamine.jax import losses +from dopamine.jax.agents.dqn import dqn_agent +from dopamine.jax.agents.rainbow import rainbow_agent as dopamine_rainbow_agent +from dopamine.labs.atari_100k import atari_100k_rainbow_agent +from dopamine.labs.atari_100k.replay_memory import subsequence_replay_buffer as replay_buffers +from dopamine.labs.atari_100k.spr_networks import SPRNetwork +import gin +import jax +import jax.numpy as jnp +import optax +import tensorflow as tf + + +@functools.partial( + jax.vmap, in_axes=(None, 0, 0, None, None), axis_name='batch' +) +def get_logits(model, states, actions, do_rollout, rng): + results = model(states, actions=actions, do_rollout=do_rollout, key=rng)[0] + return results.logits, results.latent + + +@functools.partial( + jax.vmap, in_axes=(None, 0, 0, None, None), axis_name='batch' +) +def get_q_values(model, states, actions, do_rollout, rng): + results = model(states, actions=actions, do_rollout=do_rollout, key=rng)[0] + return results.q_values, results.latent + + +@functools.partial(jax.vmap, in_axes=(None, 0, None), axis_name='batch') +def get_spr_targets(model, states, key): + results = model(states, key) + return results + + +@functools.partial( + jax.jit, + static_argnames=( + 'network_def', + 'optimizer', + 'double_dqn', + 'distributional', + 'spr_weight', + 'cumulative_gamma', + ), +) +def train( + network_def, + online_params, + target_params, + optimizer, + optimizer_state, + states, + actions, + next_states, + rewards, + terminals, + same_traj_mask, + loss_weights, + support, + cumulative_gamma, + double_dqn, + distributional, + rng, + spr_weight, +): + """Run a training step.""" + + current_state = states[:, 0] + # Split the current rng into 2 for updating the rng after this call + rng, rng1, rng2 = jax.random.split(rng, num=3) + use_spr = spr_weight > 0 + + def q_online(state, key, actions=None, do_rollout=False): + return network_def.apply( + online_params, + state, + actions=actions, + do_rollout=do_rollout, + key=key, + support=support, + mutable=['batch_stats'], + ) + + def q_target(state, key): + return network_def.apply( + target_params, state, 
key=key, support=support, mutable=['batch_stats'] + ) + + def encode_project(state, key): + latent, _ = network_def.apply( + target_params, state, method=network_def.encode, mutable=['batch_stats'] + ) + latent = latent.reshape(-1) + return network_def.apply( + target_params, + latent, + key=key, + eval_mode=True, + method=network_def.project, + ) + + def loss_fn(params, target, spr_targets, loss_multipliers): + """Computes the distributional loss for C51 or huber loss for DQN.""" + + def q_online(state, key, actions=None, do_rollout=False): + return network_def.apply( + params, + state, + actions=actions, + do_rollout=do_rollout, + key=key, + support=support, + mutable=['batch_stats'], + ) + + if distributional: + (logits, spr_predictions) = get_logits( + q_online, current_state, actions[:, :-1], use_spr, rng + ) + logits = jnp.squeeze(logits) + # Fetch the logits for its selected action. We use vmap to perform this + # indexing across the batch. + chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions[:, 0]) + dqn_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)( + target, chosen_action_logits + ) + else: + q_values, spr_predictions = get_q_values( + q_online, current_state, actions[:, :-1], use_spr, rng + ) + q_values = jnp.squeeze(q_values) + replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions[:, 0]) + dqn_loss = jax.vmap(losses.huber_loss)(target, replay_chosen_q) + + if use_spr: + # transpose to move from (time, batch, latent_dim) to + # (batch, time, latent_dim) to match targets + spr_predictions = spr_predictions.transpose(1, 0, 2) + + # Calculate SPR loss (normalized L2) + spr_predictions = spr_predictions / jnp.linalg.norm( + spr_predictions, 2, -1, keepdims=True + ) + spr_targets = spr_targets / jnp.linalg.norm( + spr_targets, 2, -1, keepdims=True + ) + + spr_loss = jnp.power(spr_predictions - spr_targets, 2).sum(-1) + + # Zero out loss for predictions that cross into the next episode + spr_loss = (spr_loss * same_traj_mask.transpose(1, 0)).mean(0) + else: + spr_loss = 0 + + loss = dqn_loss + spr_weight * spr_loss + + mean_loss = jnp.mean(loss_multipliers * loss) + return mean_loss, (loss, dqn_loss, spr_loss) + + # Use the weighted mean loss for gradient computation. + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + target = target_output( + q_online, + q_target, + next_states, + rewards, + terminals, + support, + cumulative_gamma, + double_dqn, + distributional, + rng1, + ) + + if use_spr: + future_states = states[:, 1:] + spr_targets = get_spr_targets( + encode_project, + future_states.reshape(-1, *future_states.shape[2:]), + rng1, + ) + spr_targets = spr_targets.reshape( + *future_states.shape[:2], *spr_targets.shape[1:] + ).transpose(1, 0, 2) + else: + spr_targets = None + + # Get the unweighted loss without taking its mean for updating priorities. 
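+  # grad_fn returns (mean_loss, (per_sample_loss, dqn_loss, spr_loss)) plus
+  # the gradients; the per-sample dqn_loss is what the agent later uses to
+  # set replay priorities.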
+ (mean_loss, (_, dqn_loss, spr_loss)), grad = grad_fn( + online_params, target, spr_targets, loss_weights + ) + updates, optimizer_state = optimizer.update(grad, optimizer_state) + online_params = optax.apply_updates(online_params, updates) + return optimizer_state, online_params, mean_loss, dqn_loss, spr_loss, rng2 + + +@functools.partial( + jax.vmap, + in_axes=(None, None, 0, 0, 0, None, None, None, None, None), + axis_name='batch', +) +def target_output( + model, + target_network, + next_states, + rewards, + terminals, + support, + cumulative_gamma, + double_dqn, + distributional, + rng, +): + """Builds the C51 target distribution or DQN target Q-values.""" + is_terminal_multiplier = 1.0 - terminals.astype(jnp.float32) + # Incorporate terminal state to discount factor. + gamma_with_terminal = cumulative_gamma * is_terminal_multiplier + + target_network_dist, _ = target_network(next_states, key=rng) + if double_dqn: + # Use the current network for the action selection + next_state_target_outputs, _ = model(next_states, key=rng) + else: + next_state_target_outputs = target_network_dist + # Action selection using Q-values for next-state + q_values = jnp.squeeze(next_state_target_outputs.q_values) + next_qt_argmax = jnp.argmax(q_values) + + if distributional: + # Compute the target Q-value distribution + probabilities = jnp.squeeze(target_network_dist.probabilities) + next_probabilities = probabilities[next_qt_argmax] + target_support = rewards + gamma_with_terminal * support + target = dopamine_rainbow_agent.project_distribution( + target_support, next_probabilities, support + ) + else: + # Compute the target Q-value + next_q_values = jnp.squeeze(target_network_dist.q_values) + replay_next_qt_max = next_q_values[next_qt_argmax] + target = rewards + gamma_with_terminal * replay_next_qt_max + + return jax.lax.stop_gradient(target) + + +@gin.configurable +class SPRAgent(atari_100k_rainbow_agent.Atari100kRainbowAgent): + """A compact implementation of SPR in Jax.""" + + def __init__( + self, + num_actions, + jumps=5, + spr_weight=5, + summary_writer=None, + seed=None, + epsilon_fn=dqn_agent.linearly_decaying_epsilon, + network=SPRNetwork, + ): + """Initializes the agent and constructs the necessary components. + + Args: + num_actions: int, number of actions the agent can take at any state. + jumps: int >= 0, number of SPR prediction steps to do. + spr_weight: float, weight given to the SPR loss. + summary_writer: SummaryWriter object, for outputting training statistics. + seed: int, a seed for Jax RNG and initialization. + epsilon_fn: Type of epsilon decay to use. By default, linearly_decaying + will use e-greedy during initial data collection, matching the PyTorch + codebase. + network: Network class to use. + """ + logging.info( + 'Creating %s agent with the following parameters:', + self.__class__.__name__, + ) + logging.info('\t spr_weight: %s', spr_weight) + logging.info('\t jumps: %s', jumps) + self._jumps = jumps + self.spr_weight = spr_weight + super().__init__( + num_actions=num_actions, + summary_writer=summary_writer, + seed=seed, + network=network, + ) + + # Parent class JaxFullRainbowAgent will overwrite this with the wrong value, + # so just reverse its change. 
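+    # With the default linearly_decaying_epsilon, epsilon anneals from 1.0 to
+    # epsilon_train over epsilon_decay_period steps once min_replay_history
+    # transitions have been collected.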
+ self.epsilon_fn = epsilon_fn + self.start_time = time.time() + + def _build_networks_and_optimizer(self): + self._rng, rng = jax.random.split(self._rng) + self.online_params = self.network_def.init( + rng, + x=self.state, + actions=jnp.zeros((5,)), + do_rollout=self.spr_weight > 0, + support=self._support, + ) + self.optimizer = dqn_agent.create_optimizer(self._optimizer_name) + self.optimizer_state = self.optimizer.init(self.online_params) + self.target_network_params = copy.deepcopy(self.online_params) + + def _build_replay_buffer(self): + """Creates the replay buffer used by the agent.""" + if self._replay_scheme not in ['uniform', 'prioritized']: + raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme)) + if self._replay_scheme == 'prioritized': + return replay_buffers.PrioritizedJaxSubsequenceParallelEnvReplayBuffer( + observation_shape=self.observation_shape, + stack_size=self.stack_size, + update_horizon=self.update_horizon, + gamma=self.gamma, + subseq_len=self._jumps + 1, + observation_dtype=self.observation_dtype, + ) + else: + return replay_buffers.JaxSubsequenceParallelEnvReplayBuffer( + observation_shape=self.observation_shape, + stack_size=self.stack_size, + update_horizon=self.update_horizon, + gamma=self.gamma, + subseq_len=self._jumps + 1, + observation_dtype=self.observation_dtype, + ) + + def _sample_from_replay_buffer(self): + self._rng, rng = jax.random.split(self._rng) + samples = self._replay.sample_transition_batch(rng) + types = self._replay.get_transition_elements() + self.replay_elements = collections.OrderedDict() + for element, element_type in zip(samples, types): + self.replay_elements[element_type.name] = element + + def _training_step_update(self): + """Gradient update during every training step.""" + + inter_batch_time = time.time() - self.start_time + self.start_time = time.time() + + sample_start_time = time.time() + self._sample_from_replay_buffer() + sample_time = time.time() - sample_start_time + + aug_start_time = time.time() + # Add code for data augmentation. + self._rng, rng1, rng2 = jax.random.split(self._rng, num=3) + states = self.train_preprocess_fn(self.replay_elements['state'], rng=rng1) + next_states = self.train_preprocess_fn( + self.replay_elements['next_state'][:, 0], rng=rng2 + ) + + aug_time = time.time() - aug_start_time + train_start_time = time.time() + + if self._replay_scheme == 'prioritized': + # The original prioritized experience replay uses a linear exponent + # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of + # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) + # suggested a fixed exponent actually performs better, except on Pong. + probs = self.replay_elements['sampling_probabilities'] + # Weight the loss by the inverse priorities. + loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) + loss_weights /= jnp.max(loss_weights) + else: + # Uniform weights if not using prioritized replay. 
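+      # Each sampled subsequence then contributes equally to the weighted mean
+      # loss computed in train().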
+ loss_weights = jnp.ones(states.shape[0]) + + ( + self.optimizer_state, + self.online_params, + mean_loss, + dqn_loss, + spr_loss, + self._rng, + ) = train( + network_def=self.network_def, + online_params=self.online_params, + target_params=self.target_network_params, + optimizer=self.optimizer, + optimizer_state=self.optimizer_state, + states=states, + actions=self.replay_elements['action'], + next_states=next_states, + rewards=self.replay_elements['reward'][:, 0], + terminals=self.replay_elements['terminal'][:, 0], + same_traj_mask=self.replay_elements['same_trajectory'][:, 1:], + loss_weights=loss_weights, + support=self._support, + cumulative_gamma=self.cumulative_gamma, + double_dqn=self._double_dqn, + distributional=self._distributional, + rng=self._rng, + spr_weight=self.spr_weight, + ) + + if self._replay_scheme == 'prioritized': + # Rainbow and prioritized replay are parametrized by an exponent + # alpha, but in both cases it is set to 0.5 - for simplicity's sake we + # leave it as is here, using the more direct sqrt(). Taking the square + # root "makes sense", as we are dealing with a squared loss. Add a + # small nonzero value to the loss to avoid 0 priority items. While + # technically this may be okay, setting all items to 0 priority will + # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. + self._replay.set_priority( + self.replay_elements['indices'], jnp.sqrt(dqn_loss + 1e-10) + ) + + train_time = time.time() - train_start_time + if ( + self.summary_writer is not None + and self.training_steps > 0 + and self.training_steps % self.summary_writing_frequency == 0 + ): + step = self.training_steps + tf.summary.scalar('TotalLoss', float(mean_loss), step=step) + tf.summary.scalar('DQNLoss', float(dqn_loss.mean()), step=step) + tf.summary.scalar('SPRLoss', float(spr_loss.mean()), step=step) + tf.summary.scalar('InterbatchTime', float(inter_batch_time), step=step) + tf.summary.scalar('TrainTime', float(train_time), step=step) + tf.summary.scalar('SampleTime', float(sample_time), step=step) + tf.summary.scalar('AugTime', float(aug_time), step=step) + self.summary_writer.flush() diff --git a/dopamine/labs/atari_100k/spr_networks.py b/dopamine/labs/atari_100k/spr_networks.py new file mode 100644 index 00000000..44184c2e --- /dev/null +++ b/dopamine/labs/atari_100k/spr_networks.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Copyright 2023 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Networks for SPR in Jax+Dopamine.""" + +import collections +import functools +import time +from typing import Any, Callable, Optional, Tuple + +from dopamine.jax.networks import preprocess_atari_inputs +from flax import linen as nn +import gin +import jax +from jax import lax +from jax import random +from jax.nn import initializers +import jax.numpy as jnp +import numpy as onp + +PRNGKey = Any +Array = Any +Shape = Tuple[int] +Dtype = Any + + +SPROutputType = collections.namedtuple( + 'RL_network', ['q_values', 'logits', 'probabilities', 'latent'] +) + + +def _absolute_dims(rank, dims): + return tuple([rank + dim if dim < 0 else dim for dim in dims]) + + +# --------------------------- < NoisyNetwork >--------------------------------- +# Noisy networks for SPR need to be called multiple times with and without +# noise, so we have a slightly customized implementation where eval_mode +# is an argument to __call__ rather than an attribute of the class. +@gin.configurable +class NoisyNetwork(nn.Module): + """Noisy Network from Fortunato et al. (2018).""" + + features: int = 512 + + @staticmethod + def sample_noise(key, shape): + return random.normal(key, shape) + + @staticmethod + def f(x): + # See (10) and (11) in Fortunato et al. (2018). + return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5)) + + @nn.compact + def __call__(self, x, rng_key, bias=True, kernel_init=None, eval_mode=False): + """Call the noisy layer. + + Args: + x: Data point. jnp.float32 tensor, without batch dimension + rng_key: JAX prng key + bias: Whether or not to use bias params (static) + kernel_init: Init function for kernels + eval_mode: Enable eval mode. Disables noise parameters. + + Returns: + The transformed output. JNP tensor. + """ + + def mu_init(key, shape): + # Initialization of mean noise parameters (Section 3.2) + low = -1 / jnp.power(x.shape[-1], 0.5) + high = 1 / jnp.power(x.shape[-1], 0.5) + return random.uniform(key, minval=low, maxval=high, shape=shape) + + def sigma_init(key, shape, dtype=jnp.float32): # pylint: disable=unused-argument + # Initialization of sigma noise parameters (Section 3.2) + return jnp.ones(shape, dtype) * (0.5 / onp.sqrt(x.shape[-1])) + + # Factored gaussian noise in (10) and (11) in Fortunato et al. (2018). + p = NoisyNetwork.sample_noise(rng_key, [x.shape[-1], 1]) + q = NoisyNetwork.sample_noise(rng_key, [1, self.features]) + f_p = NoisyNetwork.f(p) + f_q = NoisyNetwork.f(q) + w_epsilon = f_p * f_q + b_epsilon = jnp.squeeze(f_q) + + # See (8) and (9) in Fortunato et al. (2018) for output computation. + w_mu = self.param('kernel', mu_init, (x.shape[-1], self.features)) + w_sigma = self.param('kernell', sigma_init, (x.shape[-1], self.features)) + w_epsilon = jnp.where( + eval_mode, + onp.zeros(shape=(x.shape[-1], self.features), dtype=onp.float32), + w_epsilon, + ) + w = w_mu + jnp.multiply(w_sigma, w_epsilon) + ret = jnp.matmul(x, w) + + b_epsilon = jnp.where( + eval_mode, + onp.zeros(shape=(self.features,), dtype=onp.float32), + b_epsilon, + ) + b_mu = self.param('bias', mu_init, (self.features,)) + b_sigma = self.param('biass', sigma_init, (self.features,)) + b = b_mu + jnp.multiply(b_sigma, b_epsilon) + return jnp.where(bias, ret + b, ret) + + +# -------------------------- < RainbowNetwork >--------------------------------- + + +class NoStatsBatchNorm(nn.Module): + """A version of BatchNorm that does not track running statistics. + + For use in places where this functionality is not available in Jax. + Attributes: + axis: the feature or non-batch axis of the input. 
+ epsilon: a small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + use_bias: if True, bias (beta) is added. + use_scale: if True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: initializer for bias, by default, zero. + scale_init: initializer for scale, by default, one. + axis_name: the axis name used to combine batch statistics from multiple + devices. See `jax.pmap` for a description of axis names (default: None). + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, `[[0, 1], [2, 3]]` would independently batch-normalize over the + examples on the first two and last two devices. See `jax.lax.psum` for + more details. + """ + + use_running_average: Optional[bool] = None + axis: int = -1 + epsilon: float = 1e-5 + dtype: Dtype = jnp.float32 + use_bias: bool = True + use_scale: bool = True + bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros + scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones + axis_name: Optional[str] = None + axis_index_groups: Any = None + + @nn.compact + def __call__(self, x, use_running_average: Optional[bool] = None): + """Normalizes the input using batch statistics. + + NOTE: + During initialization (when parameters are mutable) the running average + of the batch statistics will not be updated. Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats will be + used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). 
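+
+    Unlike flax.linen.BatchNorm there are no stored running averages, so the
+    use_running_average argument has no effect here; the statistics of the
+    current batch are always used.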
+ """ + x = jnp.asarray(x, jnp.float32) + axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) + axis = _absolute_dims(x.ndim, axis) + feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) + reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) + reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) + + # see NOTE above on initialization behavior + initializing = self.is_mutable_collection('params') + + mean = jnp.mean(x, axis=reduction_axis, keepdims=False) + mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) + if self.axis_name is not None and not initializing: + concatenated_mean = jnp.concatenate([mean, mean2]) + mean, mean2 = jnp.split( + lax.pmean( + concatenated_mean, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups, + ), + 2, + ) + var = mean2 - lax.square(mean) + + y = x - mean.reshape(feature_shape) + mul = lax.rsqrt(var + self.epsilon) + if self.use_scale: + scale = self.param( + 'scale', self.scale_init, reduced_feature_shape + ).reshape(feature_shape) + mul = mul * scale + y = y * mul + if self.use_bias: + bias = self.param('bias', self.bias_init, reduced_feature_shape).reshape( + feature_shape + ) + y = y + bias + return jnp.asarray(y, self.dtype) + + +def feature_layer(noisy, features): + """Network feature layer depending on whether noisy_nets are used on or not.""" + if noisy: + net = NoisyNetwork(features=features) + else: + net = nn.Dense(features, kernel_init=nn.initializers.xavier_uniform()) + + def apply(x, key, eval_mode): + if noisy: + return net(x, key, True, None, eval_mode) # pytype: disable=wrong-arg-count + else: + return net(x) + + return net, apply + + +def renormalize(tensor, has_batch=False): + shape = tensor.shape + if not has_batch: + tensor = jnp.expand_dims(tensor, 0) + tensor = tensor.reshape(tensor.shape[0], -1) + max_val = jnp.max(tensor, axis=-1, keepdims=True) + min_val = jnp.min(tensor, axis=-1, keepdims=True) + return ((tensor - min_val) / (max_val - min_val + 1e-5)).reshape(*shape) + + +class ConvTMCell(nn.Module): + """MuZero-style transition model cell, used for SPR. + + Attributes: + num_actions: how many actions are possible (shape of one-hot vector) + latent_dim: number of channels in representation. + renormalize: whether or not to apply renormalization. + """ + + num_actions: int + latent_dim: int + renormalize: bool + + def setup(self): + self.bn = NoStatsBatchNorm(axis=-1, axis_name='batch') + + @nn.compact + def __call__(self, x, action, eval_mode=False, key=None): + sizes = [self.latent_dim, self.latent_dim] + kernel_sizes = [3, 3] + stride_sizes = [1, 1] + + action_onehot = jax.nn.one_hot(action, self.num_actions) + action_onehot = jax.lax.broadcast(action_onehot, (x.shape[-3], x.shape[-2])) + x = jnp.concatenate([x, action_onehot], -1) + for layer in range(1): + x = nn.Conv( + features=sizes[layer], + kernel_size=(kernel_sizes[layer], kernel_sizes[layer]), + strides=(stride_sizes[layer], stride_sizes[layer]), + kernel_init=nn.initializers.xavier_uniform(), + )(x) + x = nn.relu(x) + x = nn.Conv( + features=sizes[-1], + kernel_size=(kernel_sizes[-1], kernel_sizes[-1]), + strides=(stride_sizes[-1], stride_sizes[-1]), + kernel_init=nn.initializers.xavier_uniform(), + )(x) + x = nn.relu(x) + + if self.renormalize: + x = renormalize(x) + + return x, x + + +class RainbowCNN(nn.Module): + """A Jax implementation of the standard 3-layer CNN used in Atari. + + Attributes: + padding: which padding style to use. 
Defaults to SAME, which yields larger + final latents. + """ + + padding: Any = 'SAME' + + stack_sizes: Tuple[int, ...] = (32, 64, 64) + + @nn.compact + def __call__(self, x): + # x = x[None, Ellipsis] + hidden_sizes = self.stack_sizes + kernel_sizes = [8, 4, 3] + stride_sizes = [4, 2, 1] + for layer in range(3): + x = nn.Conv( + features=hidden_sizes[layer], + kernel_size=(kernel_sizes[layer], kernel_sizes[layer]), + strides=(stride_sizes[layer], stride_sizes[layer]), + kernel_init=nn.initializers.xavier_uniform(), + padding=self.padding, + )(x) + x = nn.relu(x) # flatten + return x + + +class TransitionModel(nn.Module): + """A Jax implementation of the SPR transition model, leveraging scan. + + Attributes: + num_actions: How many possible actions exist. + latent_dim: Output size. + renormalize: Whether or not to apply renormalization. + """ + + num_actions: int + latent_dim: int + renormalize: bool + + @nn.compact + def __call__(self, x, action): + scan = nn.scan( + ConvTMCell, + in_axes=0, + out_axes=0, + variable_broadcast=['params'], + split_rngs={'params': False}, + )( + latent_dim=self.latent_dim, + num_actions=self.num_actions, + renormalize=self.renormalize, + ) + return scan(x, action) + + +@gin.configurable +class SPRNetwork(nn.Module): + """Jax Rainbow network for Full Rainbow. + + Attributes: + num_actions: The number of actions the agent can take at any state. + num_atoms: The number of buckets of the value function distribution. + noisy: Whether to use noisy networks. + dueling: Whether to use dueling network architecture. + distributional: Whether to use distributional RL. + """ + + num_actions: int + num_atoms: int + noisy: bool + dueling: bool + distributional: bool + renormalize: bool = True + padding: Any = 'SAME' + inputs_preprocessed: bool = True + project_relu: bool = False + + def setup(self): + self.transition_model = TransitionModel( + num_actions=self.num_actions, + latent_dim=64, + renormalize=self.renormalize, + ) + self.projection, self.apply_projection = feature_layer(self.noisy, 512) + self.predictor = nn.Dense(512) + self.encoder = RainbowCNN(stack_sizes=(32, 64, 64)) + + def encode(self, x): + latent = self.encoder(x) + if self.renormalize: + latent = renormalize(latent) + return latent + + def project(self, x, key, eval_mode): + projected = self.apply_projection(x, key=key, eval_mode=eval_mode) + if self.project_relu: + projected = nn.relu(projected) + return projected + + @functools.partial(jax.vmap, in_axes=(None, 0, None, None)) + def spr_predict(self, x, key, eval_mode): + projected = self.apply_projection(x, key=key, eval_mode=eval_mode) + if self.project_relu: + return nn.relu(self.predictor(nn.relu(projected))) + else: + return self.predictor(projected) + + def spr_rollout(self, latent, actions, key): + _, pred_latents = self.transition_model(latent, actions) + predictions = self.spr_predict( + pred_latents.reshape(pred_latents.shape[0], -1), key, True + ) + return predictions + + @nn.compact + def __call__( + self, + x, + support, + actions=None, + do_rollout=False, + eval_mode=False, + key=None, + ): + if not self.inputs_preprocessed: + x = preprocess_atari_inputs(x) + + # Generate a random number generation key if not provided + if key is None: + key = random.PRNGKey(int(time.time() * 1e6)) + + latent = self.encode(x) + x = self.apply_projection( + latent.reshape(-1), key, eval_mode + ) # Single hidden layer of size 512 + x = nn.relu(x) + + if self.dueling: + key, rng1, rng2 = random.split(key, 3) + _, adv_net = feature_layer(self.noisy, 
self.num_actions * self.num_atoms) + _, val_net = feature_layer(self.noisy, self.num_atoms) + adv = adv_net(x, rng1, eval_mode) + value = val_net(x, rng2, eval_mode) + adv = adv.reshape((self.num_actions, self.num_atoms)) + value = value.reshape((1, self.num_atoms)) + logits = value + (adv - (jnp.mean(adv, -2, keepdims=True))) + else: + key, rng1 = random.split(key, 2) + _, adv_net = feature_layer(self.noisy, self.num_actions * self.num_atoms) + x = adv_net(x, rng1, eval_mode) + logits = x.reshape((self.num_actions, self.num_atoms)) + + if do_rollout: + latent = self.spr_rollout(latent, actions, key) + + if self.distributional: + probabilities = jnp.squeeze(nn.softmax(logits)) + q_values = jnp.squeeze(jnp.sum(support * probabilities, axis=-1)) + return SPROutputType(q_values, logits, probabilities, latent) + + q_values = jnp.squeeze(logits) + return SPROutputType(q_values, None, None, latent) diff --git a/dopamine/labs/atari_100k/train.py b/dopamine/labs/atari_100k/train.py index d6064b15..13deb41e 100644 --- a/dopamine/labs/atari_100k/train.py +++ b/dopamine/labs/atari_100k/train.py @@ -26,26 +26,53 @@ from dopamine.discrete_domains import run_experiment from dopamine.discrete_domains import train as base_train from dopamine.labs.atari_100k import atari_100k_rainbow_agent +from dopamine.labs.atari_100k import atari_100k_runner from dopamine.labs.atari_100k import eval_run_experiment +from dopamine.labs.atari_100k import spr_agent import numpy as np import tensorflow as tf +.learning.deepmind.xmanager2.client.google as xm # pylint: disable=unused-import + + FLAGS = flags.FLAGS -AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps'] +AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps', 'SPR'] # flags are defined when importing run_xm_preprocessing flags.DEFINE_enum('agent', 'DER', AGENTS, 'Name of the agent.') flags.DEFINE_integer('run_number', 1, 'Run number.') flags.DEFINE_boolean('max_episode_eval', True, 'Whether to use `MaxEpisodeEvalRunner` or not.') +flags.DEFINE_boolean( + 'legacy_runner', + False, + ( + 'Whether to use the legacy MaxEpisodeEvalRunner.' + ' This runner does not run parallel evaluation environments and may be' + ' easier to understand, but will be noticeably slower. It also does not' + ' guarantee that a precise number of training steps will be collected,' + ' which clashes with the technical definition of Atari 100k.' 
+ ), +) + + +def create_agent( + sess, # pylint: disable=unused-argument + environment, + seed, + agent_name: str, + summary_writer=None, +): + """Helper function for creating full rainbow-based Atari 100k agent.""" + if agent_name == 'SPR': + return spr_agent.SPRAgent( + num_actions=environment.action_space.n, + seed=seed, + summary_writer=summary_writer, + ) -def create_agent(sess, # pylint: disable=unused-argument - environment, - seed, - summary_writer=None): - """Helper function for creating full rainbow-based Atari 100k agent.""" return atari_100k_rainbow_agent.Atari100kRainbowAgent( num_actions=environment.action_space.n, seed=seed, @@ -73,13 +100,17 @@ def main(unused_argv): gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings run_experiment.load_gin_configs(gin_files, gin_bindings) # Set the Jax agent seed using the run number - create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number) - if FLAGS.max_episode_eval: + create_agent_fn = functools.partial( + create_agent, seed=FLAGS.run_number, agent_name=FLAGS.agent + ) + if FLAGS.legacy_runner: runner_fn = eval_run_experiment.MaxEpisodeEvalRunner logging.info('Using MaxEpisodeEvalRunner for evaluation.') runner = runner_fn(base_dir, create_agent_fn) else: - runner = run_experiment.Runner(base_dir, create_agent_fn) + runner = atari_100k_runner.DataEfficientAtariRunner( + base_dir, create_agent_fn + ) runner.run_experiment()