diff --git a/examples/ppo/README.md b/examples/ppo/README.md
new file mode 100644
index 0000000000..318955b985
--- /dev/null
+++ b/examples/ppo/README.md
@@ -0,0 +1,47 @@
+# Proximal Policy Optimization
+
+Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))
+to learn to play Atari games.
+
+## Requirements
+
+This example depends on the `gym`, `opencv-python` and `atari-py` packages
+in addition to `jax` and `flax`.
+
+## Supported setups
+
+The example should run with other configurations and hardware, but was explicitly
+tested on the following:
+
+| Hardware | Game | Training time | Total frames seen | TensorBoard.dev |
+| --- | --- | --- | --- | --- |
+| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) |
+
+## How to run
+
+Running `python ppo_main.py` will run the example with default
+(hyper)parameters, i.e. for 40M frames on the Pong game.
+
+By default, logging info and checkpoints will be stored in the `/tmp/ppo_training`
+directory. This can be overridden as follows:
+
+```shell
+python ppo_main.py --logdir=/my_fav_directory
+```
+
+You can also override the default (hyper)parameters, for example
+
+```shell
+python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest
+```
+
+will train the model on 20M Seaquest frames with a constant (i.e. not linearly
+decaying) learning rate and PPO clipping parameter. Checkpoints and TensorBoard
+files will be saved in `/tmp/seaquest`.
+
+Unit tests can be run using `python ppo_lib_test.py`.
+
+## How to run on Google Cloud TPU
+
+It is also possible to run this code on Google Cloud TPU. For detailed
+instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt).
+
+## Owners
+
+Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow
\ No newline at end of file
diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py
new file mode 100644
index 0000000000..4093df59bc
--- /dev/null
+++ b/examples/ppo/agent.py
@@ -0,0 +1,71 @@
+# Copyright 2020 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Agent utilities, incl. choosing the move and running in a separate process."""
+
+import multiprocessing
+import collections
+import jax
+import numpy as onp
+
+import env_utils
+
+@jax.jit
+def policy_action(model, state):
+  """Forward pass of the network."""
+  out = model(state)
+  return out
+
+
+ExpTuple = collections.namedtuple(
+    'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'])
+
+
+class RemoteSimulator:
+  """Wrap functionality for an agent emulating Atari in a separate process.
+
+  An object of this class is created for every agent.
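+
+  A minimal usage sketch of the parent side of the pipe protocol implemented
+  by `rcv_action_send_exp` below and driven by `ppo_lib.get_experience`
+  (the game name and the action value are only illustrative):
+
+    sim = RemoteSimulator('PongNoFrameskip-v4')
+    state = sim.conn.recv()   # initial observation with batch dim, (1, 84, 84, 4)
+    sim.conn.send(0)          # send the chosen action to the worker process
+    state, action, reward, done = sim.conn.recv()   # one step of experience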
+ """ + + def __init__(self, game: str): + """Start the remote process and create Pipe() to communicate with it.""" + parent_conn, child_conn = multiprocessing.Pipe() + self.proc = multiprocessing.Process( + target=rcv_action_send_exp, args=(child_conn, game)) + self.conn = parent_conn + self.proc.start() + + +def rcv_action_send_exp(conn, game: str): + """Run the remote agents. + + Receive action from the main learner, perform one step of simulation and + send back collected experience. + """ + env = env_utils.create_env(game, clip_rewards=True) + while True: + obs = env.reset() + done = False + # Observations fetched from Atari env need additional batch dimension. + state = obs[None, ...] + while not done: + conn.send(state) + action = conn.recv() + obs, reward, done, _ = env.step(action) + next_state = obs[None, ...] if not done else None + experience = (state, action, reward, done) + conn.send(experience) + if done: + break + state = next_state diff --git a/examples/ppo/default_config.py b/examples/ppo/default_config.py new file mode 100644 index 0000000000..a0b0dd9306 --- /dev/null +++ b/examples/ppo/default_config.py @@ -0,0 +1,54 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Definitions of default hyperparameters.""" + +import ml_collections + +def get_config(): + """Get the default configuration. + + The default hyperparameters originate from PPO paper arXiv:1707.06347 + and openAI baselines 2:: + https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py + """ + config = ml_collections.ConfigDict() + # The Atari game used. + config.game = 'Pong' + # Total number of frames seen during training. + config.total_frames = 40000000 + # The learning rate for the Adam optimizer. + config.learning_rate = 2.5e-4 + # Batch size used in training. + config.batch_size = 256 + # Number of agents playing in parallel. + config.num_agents = 8 + # Number of steps each agent performs in one policy unroll. + config.actor_steps = 128 + # Number of training epochs per each unroll of the policy. + config.num_epochs = 3 + # RL discount parameter. + config.gamma = 0.99 + # Generalized Advantage Estimation parameter. + config.lambda_ = 0.95 + # The PPO clipping parameter used to clamp ratios in loss function. + config.clip_param = 0.1 + # Weight of value function loss in the total loss. + config.vf_coeff = 0.5 + # Weight of entropy bonus in the total loss. + config.entropy_coeff = 0.01 + # Linearly decay learning rate and clipping parameter to zero during + # the training. + config.decaying_lr_and_clip_param = True + return config diff --git a/examples/ppo/env_utils.py b/examples/ppo/env_utils.py new file mode 100644 index 0000000000..d7004dd1cf --- /dev/null +++ b/examples/ppo/env_utils.py @@ -0,0 +1,81 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for handling the Atari environment.""" + +import collections +import gym +import numpy as onp + +import seed_rl_atari_preprocessing + +class ClipRewardEnv(gym.RewardWrapper): + """Adapted from OpenAI baselines. + + github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py + """ + + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return onp.sign(reward) + +class FrameStack: + """Implements stacking of `num_frames` last frames of the game. + + Wraps an AtariPreprocessing object. + """ + + def __init__( + self, + preproc: seed_rl_atari_preprocessing.AtariPreprocessing, + num_frames: int): + self.preproc = preproc + self.num_frames = num_frames + self.frames = collections.deque(maxlen=num_frames) + + def reset(self): + ob = self.preproc.reset() + for _ in range(self.num_frames): + self.frames.append(ob) + return self._get_array() + + def step(self, action: int): + ob, reward, done, info = self.preproc.step(action) + self.frames.append(ob) + return self._get_array(), reward, done, info + + def _get_array(self): + assert len(self.frames) == self.num_frames + return onp.concatenate(self.frames, axis=-1) + +def create_env(game: str, clip_rewards: bool): + """Create a FrameStack object that serves as environment for the `game`.""" + env = gym.make(game) + if clip_rewards: + env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} + preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env) + stack = FrameStack(preproc, num_frames=4) + return stack + +def get_num_actions(game: str): + """Get the number of possible actions of a given Atari game. + + This determines the number of outputs in the actor part of the + actor-critic model. + """ + env = gym.make(game) + return env.action_space.n diff --git a/examples/ppo/models.py b/examples/ppo/models.py new file mode 100644 index 0000000000..d51a065fa6 --- /dev/null +++ b/examples/ppo/models.py @@ -0,0 +1,67 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Class and functions to define and initialize the actor-critic model.""" + +import numpy as onp +import flax +from flax import nn +import jax.numpy as jnp + +class ActorCritic(flax.nn.Module): + """Class defining the actor-critic model.""" + + def apply(self, x, num_outputs): + """Define the convolutional network architecture. + + Architecture originates from "Human-level control through deep reinforcement + learning.", Nature 518, no. 7540 (2015): 529-533. 
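+    The model maps a batch of (84, 84, 4) stacked grayscale frames to action
+    log-probabilities (the actor head) and a scalar state-value estimate (the
+    critic head).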
+ Note that this is different than the one from "Playing atari with deep + reinforcement learning." arxiv.org/abs/1312.5602 (2013) + """ + dtype = jnp.float32 + x = x.astype(dtype) / 255. + x = nn.Conv(x, features=32, kernel_size=(8, 8), + strides=(4, 4), name='conv1', + dtype=dtype) + x = nn.relu(x) + x = nn.Conv(x, features=64, kernel_size=(4, 4), + strides=(2, 2), name='conv2', + dtype=dtype) + x = nn.relu(x) + x = nn.Conv(x, features=64, kernel_size=(3, 3), + strides=(1, 1), name='conv3', + dtype=dtype) + x = nn.relu(x) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(x, features=512, name='hidden', dtype=dtype) + x = nn.relu(x) + # Network used to both estimate policy (logits) and expected state value. + # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py + logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype) + policy_log_probabilities = nn.log_softmax(logits) + value = nn.Dense(x, features=1, name='value', dtype=dtype) + return policy_log_probabilities, value + +def create_model(key: onp.ndarray, num_outputs: int): + input_dims = (1, 84, 84, 4) # (minibatch, height, width, stacked frames) + module = ActorCritic.partial(num_outputs=num_outputs) + _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)]) + model = flax.nn.Model(module, initial_par) + return model + +def create_optimizer(model: nn.base.Model, learning_rate: float): + optimizer_def = flax.optim.Adam(learning_rate) + optimizer = optimizer_def.create(model) + return optimizer diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py new file mode 100644 index 0000000000..0451791c70 --- /dev/null +++ b/examples/ppo/ppo_lib.py @@ -0,0 +1,265 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library file which executes the PPO training""" + +import functools +from typing import Tuple, List +import jax +import jax.random +import jax.numpy as jnp +import numpy as onp +import flax +from flax import nn +from flax.metrics import tensorboard +from flax.training import checkpoints +import ml_collections + +import agent +import test_episodes + +@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) +@jax.jit +def gae_advantages( + rewards: onp.ndarray, + terminal_masks: onp.ndarray, + values: onp.ndarray, + discount: float, + gae_param: float): + """Use Generalized Advantage Estimation (GAE) to compute advantages. + + As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses + key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}. 
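+  Here delta_t = r_t + gamma * V(s_{t+1}) * terminal_mask_t - V(s_t), with
+  terminal_mask_t equal to 0 at terminal states, so the advantages can be
+  computed in a single backward pass over the collected trajectory.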
+ + Args: + rewards: array shaped (actor_steps, num_agents), rewards from the game + terminal_masks: array shaped (actor_steps, num_agents), zeros for terminal + and ones for non-terminal states + values: array shaped (actor_steps, num_agents), values estimated by critic + discount: RL discount usually denoted with gamma + gae_param: GAE parameter usually denoted with lambda + + Returns: + advantages: calculated advantages shaped (actor_steps, num_agents) + """ + assert rewards.shape[0] + 1 == values.shape[0], ('One more value needed; Eq. ' + '(12) in PPO paper requires ' + 'V(s_{t+1}) for delta_t') + advantages = [] + gae = 0. + for t in reversed(range(len(rewards))): + # Masks used to set next state value to 0 for terminal states. + value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] + delta = rewards[t] + value_diff + # Masks[t] used to ensure that values before and after a terminal state + # are independent of each other. + gae = delta + discount * gae_param * terminal_masks[t] * gae + advantages.append(gae) + advantages = advantages[::-1] + return jnp.array(advantages) + +@functools.partial(jax.jit, static_argnums=(6)) +def train_step( + optimizer: flax.optim.base.Optimizer, + trajectories: Tuple, + clip_param: float, + vf_coeff: float, + entropy_coeff: float, + lr: float, + batch_size: int): + """Compilable train step. + + Runs an entire epoch of training (i.e. the loop over + minibatches within an epoch is included here for performance reasons). + + Args: + optimizer: optimizer for the actor-critic model + trajectories: Tuple of the following five elements forming the experience: + states: shape (steps_per_agent*num_agents, 84, 84, 4) + actions: shape (steps_per_agent*num_agents, 84, 84, 4) + old_log_probs: shape (steps_per_agent*num_agents, ) + returns: shape (steps_per_agent*num_agents, ) + advantages: (steps_per_agent*num_agents, ) + clip_param: the PPO clipping parameter used to clamp ratios in loss function + vf_coeff: weighs value function loss in total loss + entropy_coeff: weighs entropy bonus in the total loss + lr: learning rate, varies between optimization steps + if decaying_lr_and_clip_param is set to true + batch_size: the minibatch size, static argument + + Returns: + optimizer: new optimizer after the parameters update + loss: loss summed over training steps + """ + def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): + states, actions, old_log_probs, returns, advantages = minibatch + log_probs, values = model(states) + values = values[:, 0] # Convert shapes: (batch, 1) to (batch, ). + probs = jnp.exp(log_probs) + entropy = jnp.sum(-probs*log_probs, axis=1).mean() + log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) + ratios = jnp.exp(log_probs_act_taken - old_log_probs) + # Advantage normalization (following the OpenAI baselines). + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + PG_loss = ratios * advantages + clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, + 1. + clip_param) + PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0) + value_loss = jnp.mean(jnp.square(returns - values), axis=0) + return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy + + iterations = trajectories[0].shape[0] // batch_size + trajectories = jax.tree_map( + lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) + loss = 0. 
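+  # Example with the default config: num_agents=8 and actor_steps=128 give
+  # 1024 transitions per policy unroll, so batch_size=256 results in 4
+  # minibatches and 4 gradient updates per epoch.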
+ for batch in zip(*trajectories): + grad_fn = jax.value_and_grad(loss_fn) + l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, + entropy_coeff) + loss += l + optimizer = optimizer.apply_gradient(grad, learning_rate=lr) + return optimizer, loss + +def get_experience( + model: nn.base.Model, + simulators: List[agent.RemoteSimulator], + steps_per_actor: int): + """Collect experience from agents. + + Runs `steps_per_actor` time steps of the game for each of the `simulators`. + """ + all_experience = [] + # Range up to steps_per_actor + 1 to get one more value needed for GAE. + for _ in range(steps_per_actor + 1): + states = [] + for sim in simulators: + state = sim.conn.recv() + states.append(state) + states = onp.concatenate(states, axis=0) + log_probs, values = agent.policy_action(model, states) + log_probs, values = jax.device_get((log_probs, values)) + probs = onp.exp(onp.array(log_probs)) + for i, sim in enumerate(simulators): + probabilities = probs[i] + action = onp.random.choice(probs.shape[1], p=probabilities) + sim.conn.send(action) + experiences = [] + for i, sim in enumerate(simulators): + state, action, reward, done = sim.conn.recv() + value = values[i, 0] + log_prob = log_probs[i][action] + sample = agent.ExpTuple(state, action, reward, value, log_prob, done) + experiences.append(sample) + all_experience.append(experiences) + return all_experience + +def process_experience( + experience: List[List[agent.ExpTuple]], + actor_steps: int, + num_agents: int, + gamma: float, + lambda_: float): + """Process experience for training, including advantage estimation. + + Args: + experience: collected from agents in the form of nested lists/namedtuple + actor_steps: number of steps each agent has completed + num_agents: number of agents that collected experience + gamma: dicount parameter + lambda_: GAE parameter + + Returns: + trajectories: trajectories readily accessible for `train_step()` function + """ + obs_shape = (84, 84, 4) + exp_dims = (actor_steps, num_agents) + values_dims = (actor_steps + 1, num_agents) + states = onp.zeros(exp_dims + obs_shape, dtype=onp.float32) + actions = onp.zeros(exp_dims, dtype=onp.int32) + rewards = onp.zeros(exp_dims, dtype=onp.float32) + values = onp.zeros(values_dims, dtype=onp.float32) + log_probs = onp.zeros(exp_dims, dtype=onp.float32) + dones = onp.zeros(exp_dims, dtype=onp.float32) + + for t in range(len(experience) - 1): # experience[-1] only for next_values + for agent_id, exp_agent in enumerate(experience[t]): + states[t, agent_id, ...] = exp_agent.state + actions[t, agent_id] = exp_agent.action + rewards[t, agent_id] = exp_agent.reward + values[t, agent_id] = exp_agent.value + log_probs[t, agent_id] = exp_agent.log_prob + # Dones need to be 0 for terminal states. + dones[t, agent_id] = float(not exp_agent.done) + for a in range(num_agents): + values[-1, a] = experience[-1][a].value + advantages = gae_advantages(rewards, dones, values, gamma, lambda_) + returns = advantages + values[:-1, :] + # After preprocessing, concatenate data from all agents. + trajectories = (states, actions, log_probs, returns, advantages) + trajectories = tuple(map( + lambda x: onp.reshape( + x, (num_agents * actor_steps,) + x.shape[2:]), + trajectories)) + return trajectories + +def train( + optimizer: flax.optim.base.Optimizer, + config: ml_collections.ConfigDict, + model_dir: str): + """Main training loop. 
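+
+  Each iteration collects num_agents * actor_steps frames of experience,
+  estimates advantages and returns with GAE, and then runs num_epochs epochs
+  of minibatch PPO updates on the shuffled trajectories.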
+ + Args: + optimizer: optimizer for the actor-critic model + config: object holding hyperparameters and the training information + model_dir: path to dictionary where checkpoints and logging info are stored + + Returns: + optimizer: the trained optimizer + """ + game = config.game + 'NoFrameskip-v4' + simulators = [agent.RemoteSimulator(game) + for _ in range(config.num_agents)] + summary_writer = tensorboard.SummaryWriter(model_dir) + loop_steps = config.total_frames // (config.num_agents * config.actor_steps) + log_frequency = 40 + checkpoint_frequency = 500 + + + for s in range(loop_steps): + # Bookkeeping and testing. + if s % log_frequency == 0: + score = test_episodes.policy_test(1, optimizer.target, game) + frames = s * config.num_agents * config.actor_steps + summary_writer.scalar('game_score', score, frames) + print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n') + if s % checkpoint_frequency == 0: + checkpoints.save_checkpoint(model_dir, optimizer, s) + + # Core training code. + alpha = 1. - s/loop_steps if config.decaying_lr_and_clip_param else 1. + all_experiences = get_experience( + optimizer.target, simulators, config.actor_steps) + trajectories = process_experience( + all_experiences, config.actor_steps, config.num_agents, config.gamma, + config.lambda_) + lr = config.learning_rate * alpha + clip_param = config.clip_param * alpha + for e in range(config.num_epochs): + permutation = onp.random.permutation( + config.num_agents * config.actor_steps) + trajectories = tuple(map(lambda x: x[permutation], trajectories)) + optimizer, loss = train_step( + optimizer, trajectories, clip_param, config.vf_coeff, + config.entropy_coeff, lr, config.batch_size) + return optimizer diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py new file mode 100644 index 0000000000..c48bbb9e52 --- /dev/null +++ b/examples/ppo/ppo_lib_test.py @@ -0,0 +1,134 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the PPO example.""" + +import jax +import flax +from flax import nn +import numpy as onp +import numpy.testing as onp_testing +from absl.testing import absltest + +import ppo_lib +import env_utils +import models + +# test GAE +class TestGAE(absltest.TestCase): + def test_gae_shape_on_random(self): + # create random data, simulating 4 parallel envs and 20 time_steps + envs, steps = 10, 100 + rewards = onp.random.choice([-1., 0., 1.], size=(steps, envs), + p=[0.01, 0.98, 0.01]) + terminal_masks = onp.ones(shape=(steps, envs), dtype=onp.float64) + values = onp.random.random(size=(steps + 1, envs)) + discount = 0.99 + gae_param = 0.95 + adv = ppo_lib.gae_advantages(rewards, terminal_masks, values, discount, + gae_param) + self.assertEqual(adv.shape, (steps, envs)) + def test_gae_hardcoded(self): + #test on small example that can be verified by hand + rewards = onp.array([[1., 0.], [0., 0.], [-1., 1.]]) + #one of the two episodes terminated in the middle + terminal_masks = onp.array([[1., 1.], [0., 1.], [1., 1.]]) + values = onp.array([[1., 1.], [1., 1.], [1., 1.], [1., 1.]]) + discount = 0.5 + gae_param = 0.25 + correct_gae = onp.array([[0.375, -0.5546875], [-1., -0.4375], [-1.5, 0.5]]) + actual_gae = ppo_lib.gae_advantages(rewards, terminal_masks, values, + discount, gae_param) + onp_testing.assert_allclose(actual_gae, correct_gae) +# test environment and preprocessing +class TestEnvironmentPreprocessing(absltest.TestCase): + def choose_random_game(self): + games = ['BeamRider', 'Breakout', 'Pong', + 'Qbert', 'Seaquest', 'SpaceInvaders'] + ind = onp.random.choice(len(games)) + return games[ind] + "NoFrameskip-v4" + + def test_creation(self): + frame_shape = (84, 84, 4) + game = self.choose_random_game() + env = env_utils.create_env(game, clip_rewards=True) + obs = env.reset() + self.assertTrue(obs.shape == frame_shape) + + def test_step(self): + frame_shape = (84, 84, 4) + game = self.choose_random_game() + env = env_utils.create_env(game, clip_rewards=False) + obs = env.reset() + actions = [1, 2, 3, 0] + for a in actions: + obs, reward, done, info = env.step(a) + self.assertTrue(obs.shape == frame_shape) + self.assertTrue(reward <= 1. and reward >= -1.) 
+ self.assertTrue(isinstance(done, bool)) + self.assertTrue(isinstance(info, dict)) + +# test the model (creation and forward pass) +class TestModel(absltest.TestCase): + def choose_random_outputs(self): + return onp.random.choice([4, 5, 6, 7, 8, 9]) + + def test_model(self): + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + outputs = self.choose_random_outputs() + model = models.create_model(subkey, outputs) + optimizer = models.create_optimizer(model, learning_rate=1e-3) + self.assertTrue(isinstance(model, nn.base.Model)) + self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) + test_batch_size, obs_shape = 10, (84, 84, 4) + random_input = onp.random.random(size=(test_batch_size,) + obs_shape) + log_probs, values = optimizer.target(random_input) + self.assertTrue(values.shape == (test_batch_size, 1)) + sum_probs = onp.sum(onp.exp(log_probs), axis=1) + self.assertTrue(sum_probs.shape == (test_batch_size, )) + onp_testing.assert_allclose(sum_probs, onp.ones((test_batch_size, )), + atol=1e-6) + +# test one optimization step +class TestOptimizationStep(absltest.TestCase): + def generate_random_data(self, num_actions): + data_len = 256 # equal to one default-sized batch + state_shape = (84, 84, 4) + states = onp.random.randint(0, 255, size=((data_len, ) + state_shape)) + actions = onp.random.choice(num_actions, size=data_len) + old_log_probs = onp.random.random(size=data_len) + returns = onp.random.random(size=data_len) + advantages = onp.random.random(size=data_len) + return states, actions, old_log_probs, returns, advantages + + def test_optimization_step(self): + num_outputs = 4 + trn_data = self.generate_random_data(num_actions=num_outputs) + clip_param = 0.1 + vf_coeff = 0.5 + entropy_coeff = 0.01 + lr = 2.5e-4 + batch_size = 256 + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + model = models.create_model(subkey, num_outputs) + optimizer = models.create_optimizer(model, learning_rate=lr) + optimizer, _ = ppo_lib.train_step( + optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr, + batch_size) + self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) + +if __name__ == '__main__': + absltest.main() diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py new file mode 100644 index 0000000000..6119550947 --- /dev/null +++ b/examples/ppo/ppo_main.py @@ -0,0 +1,49 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
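+
+"""Main entry point: builds the PPO model and optimizer and starts training."""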
+ +import os +from absl import flags +from absl import app +import jax +import jax.random +from ml_collections import config_flags + +import ppo_lib +import models +import env_utils + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'logdir', default='/tmp/ppo_training', + help=('Directory to save checkpoints and logging info.')) + +config_flags.DEFINE_config_file( + 'config', os.path.join(os.path.dirname(__file__), 'default_config.py'), + 'File path to the default configuration file.') + +def main(argv): + config = FLAGS.config + game = config.game + 'NoFrameskip-v4' + num_actions = env_utils.get_num_actions(game) + print(f'Playing {game} with {num_actions} actions') + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + model = models.create_model(subkey, num_outputs=num_actions) + optimizer = models.create_optimizer(model, learning_rate=config.learning_rate) + del model + optimizer = ppo_lib.train(optimizer, config, FLAGS.logdir) + +if __name__ == '__main__': + app.run(main) diff --git a/examples/ppo/requirements.txt b/examples/ppo/requirements.txt new file mode 100644 index 0000000000..69d6538acf --- /dev/null +++ b/examples/ppo/requirements.txt @@ -0,0 +1,5 @@ +atari-py +gym +jax +jaxlib +opencv-python \ No newline at end of file diff --git a/examples/ppo/seed_rl_atari_preprocessing.py b/examples/ppo/seed_rl_atari_preprocessing.py new file mode 100644 index 0000000000..4c57ebb47e --- /dev/null +++ b/examples/ppo/seed_rl_atari_preprocessing.py @@ -0,0 +1,226 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding=utf-8 +# Copyright 2019 The SEED Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A class implementing minimal Atari 2600 preprocessing. +Adapted from SEED RL, originally adapted from Dopamine. +""" + +from gym.spaces.box import Box +import numpy as np + +import cv2 + + +class AtariPreprocessing(object): + """A class implementing image preprocessing for Atari 2600 agents. + Specifically, this provides the following subset from the JAIR paper + (Bellemare et al., 2013) and Nature DQN paper (Mnih et al., 2015): + * Frame skipping (defaults to 4). + * Terminal signal when a life is lost (off by default). + * Grayscale and max-pooling of the last two frames. + * Downsample the screen to a square image (defaults to 84x84). + More generally, this class follows the preprocessing guidelines set down in + Machado et al. 
(2018), "Revisiting the Arcade Learning Environment: + Evaluation Protocols and Open Problems for General Agents". + It also provides random starting no-ops, which are used in the Rainbow, Apex + and R2D2 papers. + """ + + def __init__(self, environment, frame_skip=4, terminal_on_life_loss=False, + screen_size=84, max_random_noops=0): + """Constructor for an Atari 2600 preprocessor. + Args: + environment: Gym environment whose observations are preprocessed. + frame_skip: int, the frequency at which the agent experiences the game. + terminal_on_life_loss: bool, If True, the step() method returns + is_terminal=True whenever a life is lost. See Mnih et al. 2015. + screen_size: int, size of a resized Atari 2600 frame. + max_random_noops: int, maximum number of no-ops to apply at the beginning + of each episode to reduce determinism. These no-ops are applied at a + low-level, before frame skipping. + Raises: + ValueError: if frame_skip or screen_size are not strictly positive. + """ + if frame_skip <= 0: + raise ValueError('Frame skip should be strictly positive, got {}'. + format(frame_skip)) + if screen_size <= 0: + raise ValueError('Target screen size should be strictly positive, got {}'. + format(screen_size)) + + self.environment = environment + self.terminal_on_life_loss = terminal_on_life_loss + self.frame_skip = frame_skip + self.screen_size = screen_size + self.max_random_noops = max_random_noops + + obs_dims = self.environment.observation_space + # Stores temporary observations used for pooling over two successive + # frames. + self.screen_buffer = [ + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) + ] + + self.game_over = False + self.lives = 0 # Will need to be set by reset(). + + @property + def observation_space(self): + # Return the observation space adjusted to match the shape of the processed + # observations. + return Box(low=0, high=255, shape=(self.screen_size, self.screen_size, 1), + dtype=np.uint8) + + @property + def action_space(self): + return self.environment.action_space + + @property + def reward_range(self): + return self.environment.reward_range + + @property + def metadata(self): + return self.environment.metadata + + def close(self): + return self.environment.close() + + def apply_random_noops(self): + """Steps self.environment with random no-ops.""" + if self.max_random_noops <= 0: + return + # Other no-ops implementations actually always do at least 1 no-op. We + # follow them. + no_ops = self.environment.np_random.randint(1, self.max_random_noops + 1) + for _ in range(no_ops): + _, _, game_over, _ = self.environment.step(0) + if game_over: + self.environment.reset() + + def reset(self): + """Resets the environment. + Returns: + observation: numpy array, the initial observation emitted by the + environment. + """ + self.environment.reset() + self.apply_random_noops() + + self.lives = self.environment.ale.lives() + self._fetch_grayscale_observation(self.screen_buffer[0]) + self.screen_buffer[1].fill(0) + return self._pool_and_resize() + + def render(self, mode): + """Renders the current screen, before preprocessing. + This calls the Gym API's render() method. + Args: + mode: Mode argument for the environment's render() method. + Valid values (str) are: + 'rgb_array': returns the raw ALE image. + 'human': renders to display via the Gym renderer. + Returns: + if mode='rgb_array': numpy array, the most recent screen. + if mode='human': bool, whether the rendering was successful. 
+ """ + return self.environment.render(mode) + + def step(self, action): + """Applies the given action in the environment. + Remarks: + * If a terminal state (from life loss or episode end) is reached, this may + execute fewer than self.frame_skip steps in the environment. + * Furthermore, in this case the returned observation may not contain valid + image data and should be ignored. + Args: + action: The action to be executed. + Returns: + observation: numpy array, the observation following the action. + reward: float, the reward following the action. + is_terminal: bool, whether the environment has reached a terminal state. + This is true when a life is lost and terminal_on_life_loss, or when the + episode is over. + info: Gym API's info data structure. + """ + accumulated_reward = 0. + + for time_step in range(self.frame_skip): + # We bypass the Gym observation altogether and directly fetch the + # grayscale image from the ALE. This is a little faster. + _, reward, game_over, info = self.environment.step(action) + accumulated_reward += reward + + if self.terminal_on_life_loss: + new_lives = self.environment.ale.lives() + is_terminal = game_over or new_lives < self.lives + self.lives = new_lives + else: + is_terminal = game_over + + if is_terminal: + break + # We max-pool over the last two frames, in grayscale. + elif time_step >= self.frame_skip - 2: + t = time_step - (self.frame_skip - 2) + self._fetch_grayscale_observation(self.screen_buffer[t]) + + # Pool the last two observations. + observation = self._pool_and_resize() + + self.game_over = game_over + return observation, accumulated_reward, is_terminal, info + + def _fetch_grayscale_observation(self, output): + """Returns the current observation in grayscale. + The returned observation is stored in 'output'. + Args: + output: numpy array, screen buffer to hold the returned observation. + Returns: + observation: numpy array, the current observation in grayscale. + """ + self.environment.ale.getScreenGrayscale(output) + return output + + def _pool_and_resize(self): + """Transforms two frames into a Nature DQN observation. + For efficiency, the transformation is done in-place in self.screen_buffer. + Returns: + transformed_screen: numpy array, pooled, resized screen. + """ + # Pool if there are enough screens to do so. + if self.frame_skip > 1: + np.maximum(self.screen_buffer[0], self.screen_buffer[1], + out=self.screen_buffer[0]) + + transformed_image = cv2.resize(self.screen_buffer[0], + (self.screen_size, self.screen_size), + interpolation=cv2.INTER_LINEAR) + int_image = np.asarray(transformed_image, dtype=np.uint8) + return np.expand_dims(int_image, axis=2) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py new file mode 100644 index 0000000000..157ce0c9f8 --- /dev/null +++ b/examples/ppo/test_episodes.py @@ -0,0 +1,51 @@ +# Copyright 2020 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test policy by playing a full Atari game.""" + +import itertools +import flax +import numpy as onp + +import env_utils +import agent + +def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str): + """Perform a test of the policy in Atari environment. + + Args: + n_episodes: number of full Atari episodes to test on + model: the actor-critic model being tested + game: defines the Atari game to test on + + Returns: + total_reward: obtained score + """ + test_env = env_utils.create_env(game, clip_rewards=False) + for _ in range(n_episodes): + obs = test_env.reset() + state = obs[None, ...] # add batch dimension + total_reward = 0.0 + for t in itertools.count(): + log_probs, _ = agent.policy_action(model, state) + probs = onp.exp(onp.array(log_probs, dtype=onp.float32)) + probabilities = probs[0] / probs[0].sum() + action = onp.random.choice(probs.shape[1], p=probabilities) + obs, reward, done, _ = test_env.step(action) + total_reward += reward + next_state = obs[None, ...] if not done else None + state = next_state + if done: + break + return total_reward diff --git a/setup.py b/setup.py index 4a8a1e3762..6846acfa4a 100644 --- a/setup.py +++ b/setup.py @@ -33,8 +33,11 @@ ] tests_require = [ + "atari-py", + "gym", "jaxlib", "ml-collections", + "opencv-python", "pytest", "pytest-cov", "pytest-xdist==1.34.0", # upgrading to 2.0 broke tests, need to investigate