Commit

Use ml_collections for hyperparameter handling
wrzadkow committed Sep 30, 2020
1 parent 342786b commit a4dade8
Showing 4 changed files with 79 additions and 95 deletions.
12 changes: 9 additions & 3 deletions examples/ppo/README.md
@@ -20,10 +20,16 @@ tested on the following:
## How to run

Running `python ppo_main.py` will run the example with default
(hyper)parameters, i.e. for 40M frames on the Pong game. You can override the
default parameters, for example
(hyper)parameters, i.e. for 40M frames on the Pong game.

```python ppo_main.py --game=Seaquest --total_frames=20000000 --decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```
By default, logging info and checkpoints will be stored in the `/tmp/ppo_training`
directory. This can be overridden as follows:

```python ppo_main.py --logdir=/my_fav_directory```

You can also override the default (hyper)parameters, for example

```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest```

will train the model on 20M Seaquest frames with a constant (i.e. not linearly
decaying) learning rate and PPO clipping parameter. Checkpoints and tensorboard
summaries will be stored in the directory specified by `--logdir`.
40 changes: 40 additions & 0 deletions examples/ppo/default_config.py
@@ -0,0 +1,40 @@
"""Definitions of default hyperparameters."""

import ml_collections

def get_config():
  """Get the default configuration.

  The default hyperparameters originate from the PPO paper arXiv:1707.06347
  and the OpenAI Baselines ppo2 defaults:
  https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
  """
  config = ml_collections.ConfigDict()
  # The Atari game used.
  config.game = 'Pong'
  # Total number of frames seen during training.
  config.total_frames = 40000000
  # The learning rate for the Adam optimizer.
  config.learning_rate = 2.5e-4
  # Batch size used in training.
  config.batch_size = 256
  # Number of agents playing in parallel.
  config.num_agents = 8
  # Number of steps each agent performs in one policy unroll.
  config.actor_steps = 128
  # Number of training epochs per each unroll of the policy.
  config.num_epochs = 3
  # RL discount parameter.
  config.gamma = 0.99
  # Generalized Advantage Estimation parameter.
  config.lambda_ = 0.95
  # The PPO clipping parameter used to clamp ratios in the loss function.
  config.clip_param = 0.1
  # Weight of the value function loss in the total loss.
  config.vf_coeff = 0.5
  # Weight of the entropy bonus in the total loss.
  config.entropy_coeff = 0.01
  # Linearly decay the learning rate and the clipping parameter to zero
  # during training.
  config.decaying_lr_and_clip_param = True
  return config
39 changes: 20 additions & 19 deletions examples/ppo/ppo_lib.py
@@ -2,7 +2,6 @@

import functools
from typing import Tuple, List
from absl import flags
import jax
import jax.random
import jax.numpy as jnp
@@ -11,6 +10,7 @@
from flax import nn
from flax.metrics import tensorboard
from flax.training import checkpoints
import ml_collections

import agent
import test_episodes
@@ -201,22 +201,23 @@ def process_experience(

def train(
optimizer: flax.optim.base.Optimizer,
flags_: flags._flagvalues.FlagValues):
config: ml_collections.ConfigDict,
model_dir: str):
"""Main training loop.
Args:
optimizer: optimizer for the actor-critic model
flags_: object holding hyperparameters and the training information
config: object holding hyperparameters and the training information
model_dir: path to the directory where checkpoints and logging info are stored
Returns:
optimizer: the trained optimizer
"""
game = flags_.game + 'NoFrameskip-v4'
game = config.game + 'NoFrameskip-v4'
simulators = [agent.RemoteSimulator(game)
for _ in range(flags_.num_agents)]
model_dir = '/tmp/ppo_training/'
for _ in range(config.num_agents)]
summary_writer = tensorboard.SummaryWriter(model_dir)
loop_steps = flags_.total_frames // (flags_.num_agents * flags_.actor_steps)
loop_steps = config.total_frames // (config.num_agents * config.actor_steps)
log_frequency = 40
checkpoint_frequency = 500

@@ -225,26 +226,26 @@ def train(
# Bookkeeping and testing.
if s % log_frequency == 0:
score = test_episodes.policy_test(1, optimizer.target, game)
frames = s * flags_.num_agents * flags_.actor_steps
frames = s * config.num_agents * config.actor_steps
summary_writer.scalar('game_score', score, frames)
print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n')
if s % checkpoint_frequency == 0:
checkpoints.save_checkpoint(model_dir, optimizer, s)

# Core training code.
alpha = 1. - s/loop_steps if flags_.decaying_lr_and_clip_param else 1.
all_experiences = get_experience(optimizer.target, simulators,
flags_.actor_steps)
alpha = 1. - s/loop_steps if config.decaying_lr_and_clip_param else 1.
all_experiences = get_experience(
optimizer.target, simulators, config.actor_steps)
trajectories = process_experience(
all_experiences, flags_.actor_steps, flags_.num_agents, flags_.gamma,
flags_.lambda_)
lr = flags_.learning_rate * alpha
clip_param = flags_.clip_param * alpha
for e in range(flags_.num_epochs):
all_experiences, config.actor_steps, config.num_agents, config.gamma,
config.lambda_)
lr = config.learning_rate * alpha
clip_param = config.clip_param * alpha
for e in range(config.num_epochs):
permutation = onp.random.permutation(
flags_.num_agents * flags_.actor_steps)
config.num_agents * config.actor_steps)
trajectories = tuple(map(lambda x: x[permutation], trajectories))
optimizer, loss = train_step(
optimizer, trajectories, clip_param, flags_.vf_coeff,
flags_.entropy_coeff, lr, flags_.batch_size)
optimizer, trajectories, clip_param, config.vf_coeff,
config.entropy_coeff, lr, config.batch_size)
return optimizer
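
To make `decaying_lr_and_clip_param` concrete, here is a small standalone sketch
(function name and example values are illustrative, not from the commit) of the
linear schedule applied inside the loop above, where `alpha` scales both the
learning rate and the PPO clipping parameter:

```python
def annealed_hparams(step, loop_steps, learning_rate, clip_param, decay=True):
  """Linearly anneals lr and clip_param from their initial values towards zero."""
  alpha = 1. - step / loop_steps if decay else 1.
  return learning_rate * alpha, clip_param * alpha

# Halfway through training both values are scaled by 0.5:
lr, clip = annealed_hparams(step=500, loop_steps=1000,
                            learning_rate=2.5e-4, clip_param=0.1)
```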
83 changes: 10 additions & 73 deletions examples/ppo/ppo_main.py
@@ -1,98 +1,35 @@
import os
from absl import flags
from absl import app
import jax
import jax.random
from ml_collections import config_flags

import ppo_lib
import models
import env_utils

FLAGS = flags.FLAGS

# Default hyperparameters originate from PPO paper and openAI baselines 2.
# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py

flags.DEFINE_float(
'learning_rate', default=2.5e-4,
help=('The learning rate for the Adam optimizer.')
)

flags.DEFINE_integer(
'batch_size', default=256,
help=('Batch size for training.')
)

flags.DEFINE_integer(
'num_agents', default=8,
help=('Number of agents playing in parallel.')
)

flags.DEFINE_integer(
'actor_steps', default=128,
help=('Batch size for training.')
)

flags.DEFINE_integer(
'num_epochs', default=3,
help=('Number of epochs per each unroll of the policy.')
)

flags.DEFINE_float(
'gamma', default=0.99,
help=('Discount parameter.')
)

flags.DEFINE_float(
'lambda_', default=0.95,
help=('Generalized Advantage Estimation parameter.')
)

flags.DEFINE_float(
'clip_param', default=0.1,
help=('The PPO clipping parameter used to clamp ratios in loss function.')
)

flags.DEFINE_float(
'vf_coeff', default=0.5,
help=('Weighs value function loss in the total loss.')
)

flags.DEFINE_float(
'entropy_coeff', default=0.01,
help=('Weighs entropy bonus in the total loss.')
)

flags.DEFINE_boolean(
'decaying_lr_and_clip_param', default=True,
help=(('Linearly decay learning rate and clipping parameter to zero during '
'the training.'))
)

flags.DEFINE_string(
'game', default='Pong',
help=('The Atari game used.')
)

flags.DEFINE_string(
'logdir', default='/tmp/ppo_training',
help=('Directory to save checkpoints and logging info.')
)
help=('Directory to save checkpoints and logging info.'))

flags.DEFINE_integer(
'total_frames', default=40000000,
help=('Length of training (total number of frames to be seen).')
)
config_flags.DEFINE_config_file(
'config', os.path.join(os.path.dirname(__file__), 'default_config.py'),
'File path to the default configuration file.')

def main(argv):
game = FLAGS.game + 'NoFrameskip-v4'
config = FLAGS.config
game = config.game + 'NoFrameskip-v4'
num_actions = env_utils.get_num_actions(game)
print(f'Playing {game} with {num_actions} actions')
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
model = models.create_model(subkey, num_outputs=num_actions)
optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate)
optimizer = models.create_optimizer(model, learning_rate=config.learning_rate)
del model
optimizer = ppo_lib.train(optimizer, FLAGS)
optimizer = ppo_lib.train(optimizer, config, FLAGS.logdir)

if __name__ == '__main__':
app.run(main)
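
For context, a self-contained sketch of the `config_flags` pattern that
`ppo_main.py` now relies on; the script name and fields are illustrative
stand-ins, and `DEFINE_config_dict` is used instead of `DEFINE_config_file`
only to keep the example in a single file:

```python
# minimal_config_main.py -- illustrative only, not part of the commit.
from absl import app
from absl import flags
import ml_collections
from ml_collections import config_flags


def get_config():
  config = ml_collections.ConfigDict()
  config.game = 'Pong'
  config.learning_rate = 2.5e-4
  return config


FLAGS = flags.FLAGS
config_flags.DEFINE_config_dict('config', get_config())


def main(argv):
  config = FLAGS.config  # a ConfigDict, with any --config.* overrides applied
  print(config.game, config.learning_rate)


if __name__ == '__main__':
  # e.g. python minimal_config_main.py --config.game=Seaquest --config.learning_rate=1e-4
  app.run(main)
```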
