Commit

Merge pull request #470 from wrzadkow:rl-example-ppo
PiperOrigin-RevId: 334765232
Flax Authors committed Oct 1, 2020
2 parents e58dea2 + a4dade8 commit fed1aaf
Showing 12 changed files with 1,053 additions and 0 deletions.
47 changes: 47 additions & 0 deletions examples/ppo/README.md
@@ -0,0 +1,47 @@
# Proximal Policy Optimization

Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))
to learn to play Atari games.

## Requirements

This example depends on the `gym`, `opencv-python` and `atari-py` packages
in addition to `jax` and `flax`.

## Supported setups

The example should run with other configurations and hardware, but was explicitly
tested on the following:

| Hardware | Game | Training time | Total frames seen | TensorBoard.dev |
| --- | --- | --- | --- | --- |
| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) |

## How to run

Running `python ppo_main.py` will run the example with default
(hyper)parameters, i.e. for 40M frames on the Pong game.

By default, logging info and checkpoints will be stored in the `/tmp/ppo_training`
directory. This can be overridden as follows:

```python ppo_main.py --logdir=/my_fav_directory```

You can also override the default (hyper)parameters, for example

```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```

will train the model on 20M Seaquest frames with constant (i.e. not linearly
decaying) learning rate and PPO clipping parameter. Checkpoints and TensorBoard
files will be saved in `/tmp/seaquest`.

Unit tests can be run using `python ppo_lib_test.py`.

## How to run on Google Cloud TPU

It is also possible to run this code on Google Cloud TPU. For detailed
instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt).

## Owners

Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow
71 changes: 71 additions & 0 deletions examples/ppo/agent.py
@@ -0,0 +1,71 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Agent utilities, incl. choosing the move and running in separate process."""

import multiprocessing
import collections
import jax
import numpy as onp

import env_utils

@jax.jit
def policy_action(model, state):
  """Forward pass of the network."""
  out = model(state)
  return out


ExpTuple = collections.namedtuple(
    'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'])


class RemoteSimulator:
  """Wrap functionality for an agent emulating Atari in a separate process.
  An object of this class is created for every agent.
  """

  def __init__(self, game: str):
    """Start the remote process and create Pipe() to communicate with it."""
    parent_conn, child_conn = multiprocessing.Pipe()
    self.proc = multiprocessing.Process(
        target=rcv_action_send_exp, args=(child_conn, game))
    self.conn = parent_conn
    self.proc.start()


def rcv_action_send_exp(conn, game: str):
  """Run the remote agents.
  Receive action from the main learner, perform one step of simulation and
  send back collected experience.
  """
  env = env_utils.create_env(game, clip_rewards=True)
  while True:
    obs = env.reset()
    done = False
    # Observations fetched from Atari env need additional batch dimension.
    state = obs[None, ...]
    while not done:
      conn.send(state)
      action = conn.recv()
      obs, reward, done, _ = env.step(action)
      next_state = obs[None, ...] if not done else None
      experience = (state, action, reward, done)
      conn.send(experience)
      if done:
        break
      state = next_state
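
For orientation, here is a minimal learner-side sketch (not part of this commit) of the send/receive protocol above: each `RemoteSimulator` first sends a state, expects an action back, and then returns an experience tuple. The full Gym id and the number of driver steps are assumptions made purely for illustration.

```python
# Hypothetical driver sketch: receive a state from each simulator process,
# sample an action from the policy, send it back, and collect the experience.
import jax
import numpy as onp

import agent
import env_utils
import models

if __name__ == '__main__':
  game = 'PongNoFrameskip-v4'  # assumption: full Gym id, not just 'Pong'
  num_actions = env_utils.get_num_actions(game)
  model = models.create_model(jax.random.PRNGKey(0), num_outputs=num_actions)
  simulators = [agent.RemoteSimulator(game) for _ in range(4)]

  for _ in range(8):  # a few environment steps per simulator
    states = onp.concatenate([sim.conn.recv() for sim in simulators], axis=0)
    log_probs, _ = agent.policy_action(model, states)
    probs = onp.exp(onp.asarray(log_probs))
    for i, sim in enumerate(simulators):
      p = probs[i] / probs[i].sum()  # renormalize against float rounding
      sim.conn.send(onp.random.choice(len(p), p=p))
    experiences = [sim.conn.recv() for sim in simulators]
```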
54 changes: 54 additions & 0 deletions examples/ppo/default_config.py
@@ -0,0 +1,54 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions of default hyperparameters."""

import ml_collections

def get_config():
  """Get the default configuration.
  The default hyperparameters originate from the PPO paper (arXiv:1707.06347)
  and the OpenAI baselines PPO2 implementation:
  https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
  """
  config = ml_collections.ConfigDict()
  # The Atari game used.
  config.game = 'Pong'
  # Total number of frames seen during training.
  config.total_frames = 40000000
  # The learning rate for the Adam optimizer.
  config.learning_rate = 2.5e-4
  # Batch size used in training.
  config.batch_size = 256
  # Number of agents playing in parallel.
  config.num_agents = 8
  # Number of steps each agent performs in one policy unroll.
  config.actor_steps = 128
  # Number of training epochs per unroll of the policy.
  config.num_epochs = 3
  # RL discount parameter.
  config.gamma = 0.99
  # Generalized Advantage Estimation parameter.
  config.lambda_ = 0.95
  # The PPO clipping parameter used to clamp ratios in the loss function.
  config.clip_param = 0.1
  # Weight of the value function loss in the total loss.
  config.vf_coeff = 0.5
  # Weight of the entropy bonus in the total loss.
  config.entropy_coeff = 0.01
  # Linearly decay learning rate and clipping parameter to zero during
  # training.
  config.decaying_lr_and_clip_param = True
  return config
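
As a quick illustration (not part of this commit), the returned `ConfigDict` can also be overridden programmatically, mirroring what the `--config.*` command-line flags in the README do:

```python
# Minimal sketch, assuming default_config.py is importable from the working
# directory; attribute assignment on a ConfigDict overrides a default value.
import default_config

config = default_config.get_config()
config.game = 'Seaquest'          # switch the Atari game
config.total_frames = 20_000_000  # 20M frames instead of 40M
config.decaying_lr_and_clip_param = False
print(config.learning_rate, config.clip_param)  # 0.00025 0.1
```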
81 changes: 81 additions & 0 deletions examples/ppo/env_utils.py
@@ -0,0 +1,81 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for handling the Atari environment."""

import collections
import gym
import numpy as onp

import seed_rl_atari_preprocessing

class ClipRewardEnv(gym.RewardWrapper):
  """Adapted from OpenAI baselines.
  github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
  """

  def __init__(self, env):
    gym.RewardWrapper.__init__(self, env)

  def reward(self, reward):
    """Bin reward to {+1, 0, -1} by its sign."""
    return onp.sign(reward)

class FrameStack:
  """Implements stacking of `num_frames` last frames of the game.
  Wraps an AtariPreprocessing object.
  """

  def __init__(
      self,
      preproc: seed_rl_atari_preprocessing.AtariPreprocessing,
      num_frames: int):
    self.preproc = preproc
    self.num_frames = num_frames
    self.frames = collections.deque(maxlen=num_frames)

  def reset(self):
    ob = self.preproc.reset()
    for _ in range(self.num_frames):
      self.frames.append(ob)
    return self._get_array()

  def step(self, action: int):
    ob, reward, done, info = self.preproc.step(action)
    self.frames.append(ob)
    return self._get_array(), reward, done, info

  def _get_array(self):
    assert len(self.frames) == self.num_frames
    return onp.concatenate(self.frames, axis=-1)

def create_env(game: str, clip_rewards: bool):
  """Create a FrameStack object that serves as environment for the `game`."""
  env = gym.make(game)
  if clip_rewards:
    env = ClipRewardEnv(env)  # bin rewards to {-1., 0., 1.}
  preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env)
  stack = FrameStack(preproc, num_frames=4)
  return stack

def get_num_actions(game: str):
  """Get the number of possible actions of a given Atari game.
  This determines the number of outputs in the actor part of the
  actor-critic model.
  """
  env = gym.make(game)
  return env.action_space.n
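
A short usage sketch (not part of the commit) of the wrappers above; the full Gym id used here is an assumption:

```python
# Build the preprocessed, reward-clipped, frame-stacked environment and take
# one random step; observations come out as stacked 84x84 grayscale frames.
import numpy as onp

import env_utils

game = 'PongNoFrameskip-v4'  # assumption: full Gym id
env = env_utils.create_env(game, clip_rewards=True)
obs = env.reset()
print(obs.shape)  # (84, 84, 4): four stacked frames
action = onp.random.randint(env_utils.get_num_actions(game))
obs, reward, done, info = env.step(action)
print(reward, done)  # reward is already clipped to {-1., 0., 1.}
```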
67 changes: 67 additions & 0 deletions examples/ppo/models.py
@@ -0,0 +1,67 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class and functions to define and initialize the actor-critic model."""

import numpy as onp
import flax
from flax import nn
import jax.numpy as jnp

class ActorCritic(flax.nn.Module):
  """Class defining the actor-critic model."""

  def apply(self, x, num_outputs):
    """Define the convolutional network architecture.
    Architecture originates from "Human-level control through deep
    reinforcement learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing Atari with deep
    reinforcement learning." arxiv.org/abs/1312.5602 (2013).
    """
    dtype = jnp.float32
    x = x.astype(dtype) / 255.
    x = nn.Conv(x, features=32, kernel_size=(8, 8),
                strides=(4, 4), name='conv1',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(4, 4),
                strides=(2, 2), name='conv2',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(3, 3),
                strides=(1, 1), name='conv3',
                dtype=dtype)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
    x = nn.relu(x)
    # Network used to both estimate policy (logits) and expected state value.
    # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
    logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(x, features=1, name='value', dtype=dtype)
    return policy_log_probabilities, value

def create_model(key: onp.ndarray, num_outputs: int):
  input_dims = (1, 84, 84, 4)  # (minibatch, height, width, stacked frames)
  module = ActorCritic.partial(num_outputs=num_outputs)
  _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)])
  model = flax.nn.Model(module, initial_par)
  return model

def create_optimizer(model: nn.base.Model, learning_rate: float):
  optimizer_def = flax.optim.Adam(learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer
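
To make the expected shapes concrete, here is a minimal sketch (not part of the commit) that initializes the model, wraps it in the Adam optimizer, and runs one forward pass; the action count of 6 is an arbitrary example value:

```python
# Create the actor-critic model and its optimizer, then run a forward pass
# on a dummy batch of four stacked 84x84 frames.
import jax
import jax.numpy as jnp

import models

model = models.create_model(jax.random.PRNGKey(0), num_outputs=6)
optimizer = models.create_optimizer(model, learning_rate=2.5e-4)
dummy = jnp.zeros((1, 84, 84, 4), dtype=jnp.float32)  # (batch, H, W, frames)
log_probs, values = optimizer.target(dummy)
print(log_probs.shape, values.shape)  # (1, 6) and (1, 1)
```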