Merge pull request #470 from wrzadkow:rl-example-ppo
PiperOrigin-RevId: 334765232
Showing 12 changed files with 1,053 additions and 0 deletions.
@@ -0,0 +1,47 @@
# Proximal Policy Optimization

Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))
to learn to play Atari games.

## Requirements

This example depends on the `gym`, `opencv-python` and `atari-py` packages
in addition to `jax` and `flax`.

## Supported setups

The example should run with other configurations and hardware, but was explicitly
tested on the following:

| Hardware | Game | Training time | Total frames seen | TensorBoard.dev |
| --- | --- | --- | --- | --- |
| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) |

## How to run

Running `python ppo_main.py` will run the example with the default
(hyper)parameters, i.e. for 40M frames on the Pong game.

By default, logging info and checkpoints will be stored in the `/tmp/ppo_training`
directory. This can be overridden as follows:

```python ppo_main.py --logdir=/my_fav_directory```

You can also override the default (hyper)parameters, for example

```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```

will train the model on 20M Seaquest frames with a constant (i.e. not linearly
decaying) learning rate and PPO clipping parameter. Checkpoints and TensorBoard
files will be saved in `/tmp/seaquest`.

Unit tests can be run using `python ppo_lib_test.py`.

## How to run on Google Cloud TPU

It is also possible to run this code on Google Cloud TPU. For detailed
instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt).

## Owners

Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow
@@ -0,0 +1,71 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Agent utilities, incl. choosing the move and running in separate process."""

import multiprocessing
import collections
import jax
import numpy as onp

import env_utils


@jax.jit
def policy_action(model, state):
  """Forward pass of the network."""
  out = model(state)
  return out


ExpTuple = collections.namedtuple(
    'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done'])


class RemoteSimulator:
  """Wrap functionality for an agent emulating Atari in a separate process.
  An object of this class is created for every agent.
  """

  def __init__(self, game: str):
    """Start the remote process and create Pipe() to communicate with it."""
    parent_conn, child_conn = multiprocessing.Pipe()
    self.proc = multiprocessing.Process(
        target=rcv_action_send_exp, args=(child_conn, game))
    self.conn = parent_conn
    self.proc.start()


def rcv_action_send_exp(conn, game: str):
  """Run the remote agents.
  Receive action from the main learner, perform one step of simulation and
  send back collected experience.
  """
  env = env_utils.create_env(game, clip_rewards=True)
  while True:
    obs = env.reset()
    done = False
    # Observations fetched from Atari env need additional batch dimension.
    state = obs[None, ...]
    while not done:
      conn.send(state)
      action = conn.recv()
      obs, reward, done, _ = env.step(action)
      next_state = obs[None, ...] if not done else None
      experience = (state, action, reward, done)
      conn.send(experience)
      if done:
        break
      state = next_state
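
The child process above sends its current state, waits for an action, and then sends back a `(state, action, reward, done)` tuple. A minimal sketch of how the main process might drive several simulators through this protocol is shown below; it is not part of this commit, and the game id, the number of simulators, and the action-sampling details are illustrative assumptions.

```python
# Minimal sketch (not part of this commit) of a learner loop driving several
# RemoteSimulator instances via the send/recv protocol in the file above.
import jax
import numpy as onp

import agent
import env_utils
import models

game = 'PongNoFrameskip-v4'  # assumption: full gym id, not shown in this diff
num_actions = env_utils.get_num_actions(game)
simulators = [agent.RemoteSimulator(game) for _ in range(8)]
model = models.create_model(jax.random.PRNGKey(0), num_outputs=num_actions)

for _ in range(128):  # roughly one policy unroll of `actor_steps` steps
  # Each child process sends its current (batched) state first.
  states = onp.concatenate([sim.conn.recv() for sim in simulators], axis=0)
  log_probs, _ = agent.policy_action(model, states)
  probs = onp.exp(onp.asarray(log_probs))
  for i, sim in enumerate(simulators):
    # Sample an action from the policy and send it back to the child.
    p = probs[i] / probs[i].sum()  # renormalize to guard against float32 rounding
    action = onp.random.choice(num_actions, p=p)
    sim.conn.send(action)
  # Collect the (state, action, reward, done) tuples sent back by the children.
  experiences = [sim.conn.recv() for sim in simulators]
```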
@@ -0,0 +1,54 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions of default hyperparameters."""

import ml_collections


def get_config():
  """Get the default configuration.
  The default hyperparameters originate from the PPO paper (arXiv:1707.06347)
  and the OpenAI baselines ppo2 defaults:
  https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
  """
  config = ml_collections.ConfigDict()
  # The Atari game used.
  config.game = 'Pong'
  # Total number of frames seen during training.
  config.total_frames = 40000000
  # The learning rate for the Adam optimizer.
  config.learning_rate = 2.5e-4
  # Batch size used in training.
  config.batch_size = 256
  # Number of agents playing in parallel.
  config.num_agents = 8
  # Number of steps each agent performs in one policy unroll.
  config.actor_steps = 128
  # Number of training epochs per unroll of the policy.
  config.num_epochs = 3
  # RL discount parameter.
  config.gamma = 0.99
  # Generalized Advantage Estimation parameter.
  config.lambda_ = 0.95
  # The PPO clipping parameter used to clamp ratios in loss function.
  config.clip_param = 0.1
  # Weight of value function loss in the total loss.
  config.vf_coeff = 0.5
  # Weight of entropy bonus in the total loss.
  config.entropy_coeff = 0.01
  # Linearly decay learning rate and clipping parameter to zero during
  # training.
  config.decaying_lr_and_clip_param = True
  return config
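
The `--config.*` flags shown in the README map onto this `ConfigDict`; the same overrides can be made programmatically. A minimal sketch follows, assuming the file above is importable as `default_config` (the filename is not visible in this diff).

```python
# Minimal sketch (not part of this commit): overriding the defaults in Python.
# Assumption: the config file above is importable as `default_config`.
import default_config

config = default_config.get_config()
config.game = 'Seaquest'            # equivalent to --config.game=Seaquest
config.total_frames = 20_000_000    # equivalent to --config.total_frames=20000000
config.decaying_lr_and_clip_param = False
print(config)
```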
@@ -0,0 +1,81 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for handling the Atari environment."""

import collections
import gym
import numpy as onp

import seed_rl_atari_preprocessing


class ClipRewardEnv(gym.RewardWrapper):
  """Adapted from OpenAI baselines.
  github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
  """

  def __init__(self, env):
    gym.RewardWrapper.__init__(self, env)

  def reward(self, reward):
    """Bin reward to {+1, 0, -1} by its sign."""
    return onp.sign(reward)


class FrameStack:
  """Implements stacking of the last `num_frames` frames of the game.
  Wraps an AtariPreprocessing object.
  """

  def __init__(
      self,
      preproc: seed_rl_atari_preprocessing.AtariPreprocessing,
      num_frames: int):
    self.preproc = preproc
    self.num_frames = num_frames
    self.frames = collections.deque(maxlen=num_frames)

  def reset(self):
    ob = self.preproc.reset()
    for _ in range(self.num_frames):
      self.frames.append(ob)
    return self._get_array()

  def step(self, action: int):
    ob, reward, done, info = self.preproc.step(action)
    self.frames.append(ob)
    return self._get_array(), reward, done, info

  def _get_array(self):
    assert len(self.frames) == self.num_frames
    return onp.concatenate(self.frames, axis=-1)


def create_env(game: str, clip_rewards: bool):
  """Create a FrameStack object that serves as environment for the `game`."""
  env = gym.make(game)
  if clip_rewards:
    env = ClipRewardEnv(env)  # bin rewards to {-1., 0., 1.}
  preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env)
  stack = FrameStack(preproc, num_frames=4)
  return stack


def get_num_actions(game: str):
  """Get the number of possible actions of a given Atari game.
  This determines the number of outputs in the actor part of the
  actor-critic model.
  """
  env = gym.make(game)
  return env.action_space.n
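
A minimal usage sketch of these utilities, assuming a full gym id such as `'PongNoFrameskip-v4'` is passed (the exact id string used by the training script is not shown in this excerpt):

```python
# Minimal sketch (not part of this commit) of creating and stepping the
# preprocessed, frame-stacked environment defined above.
import env_utils

num_actions = env_utils.get_num_actions('PongNoFrameskip-v4')
env = env_utils.create_env('PongNoFrameskip-v4', clip_rewards=True)

# Four stacked preprocessed frames; matches the (84, 84, 4) per-example shape
# expected by the actor-critic model in this commit.
obs = env.reset()
obs, reward, done, info = env.step(0)  # reward is clipped to {-1., 0., 1.}
```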
@@ -0,0 +1,67 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Class and functions to define and initialize the actor-critic model."""

import numpy as onp
import flax
from flax import nn
import jax.numpy as jnp


class ActorCritic(flax.nn.Module):
  """Class defining the actor-critic model."""

  def apply(self, x, num_outputs):
    """Define the convolutional network architecture.
    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing Atari with deep
    reinforcement learning." arxiv.org/abs/1312.5602 (2013).
    """
    dtype = jnp.float32
    x = x.astype(dtype) / 255.
    x = nn.Conv(x, features=32, kernel_size=(8, 8),
                strides=(4, 4), name='conv1',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(4, 4),
                strides=(2, 2), name='conv2',
                dtype=dtype)
    x = nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(3, 3),
                strides=(1, 1), name='conv3',
                dtype=dtype)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(x, features=512, name='hidden', dtype=dtype)
    x = nn.relu(x)
    # Network used to both estimate policy (logits) and expected state value.
    # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
    logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(x, features=1, name='value', dtype=dtype)
    return policy_log_probabilities, value


def create_model(key: onp.ndarray, num_outputs: int):
  input_dims = (1, 84, 84, 4)  # (minibatch, height, width, stacked frames)
  module = ActorCritic.partial(num_outputs=num_outputs)
  _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)])
  model = flax.nn.Model(module, initial_par)
  return model


def create_optimizer(model: nn.base.Model, learning_rate: float):
  optimizer_def = flax.optim.Adam(learning_rate)
  optimizer = optimizer_def.create(model)
  return optimizer
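
A minimal sketch of building the model and optimizer defined above and running a forward pass, assuming the file is importable as `models` (the filename is not visible in this diff):

```python
# Minimal sketch (not part of this commit): initialize the actor-critic model
# and its Adam optimizer, then run a forward pass on a dummy observation batch.
import jax
import jax.numpy as jnp

import models  # assumption: module name for the file above

key = jax.random.PRNGKey(0)
model = models.create_model(key, num_outputs=6)        # e.g. 6 actions for Pong
optimizer = models.create_optimizer(model, learning_rate=2.5e-4)

dummy_obs = jnp.zeros((1, 84, 84, 4), dtype=jnp.uint8)  # (minibatch, h, w, stacked frames)
log_probs, value = optimizer.target(dummy_obs)           # optimizer.target is the flax.nn.Model
print(log_probs.shape, value.shape)                      # (1, 6) and (1, 1)
```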