Use ml_collections for hyperparameter handling
Showing 4 changed files with 79 additions and 95 deletions.
@@ -0,0 +1,40 @@
"""Definitions of default hyperparameters."""

import ml_collections


def get_config():
  """Get the default configuration.

  The default hyperparameters originate from PPO paper arXiv:1707.06347
  and openAI baselines 2:
  https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
  """
  config = ml_collections.ConfigDict()
  # The Atari game used.
  config.game = 'Pong'
  # Total number of frames seen during training.
  config.total_frames = 40000000
  # The learning rate for the Adam optimizer.
  config.learning_rate = 2.5e-4
  # Batch size used in training.
  config.batch_size = 256
  # Number of agents playing in parallel.
  config.num_agents = 8
  # Number of steps each agent performs in one policy unroll.
  config.actor_steps = 128
  # Number of training epochs per each unroll of the policy.
  config.num_epochs = 3
  # RL discount parameter.
  config.gamma = 0.99
  # Generalized Advantage Estimation parameter.
  config.lambda_ = 0.95
  # The PPO clipping parameter used to clamp ratios in loss function.
  config.clip_param = 0.1
  # Weight of value function loss in the total loss.
  config.vf_coeff = 0.5
  # Weight of entropy bonus in the total loss.
  config.entropy_coeff = 0.01
  # Linearly decay learning rate and clipping parameter to zero during
  # the training.
  config.decaying_lr_and_clip_param = True
  return config
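
As a quick illustration (not part of the commit), the config defined above can also be loaded and tweaked programmatically. This sketch assumes the new file is saved as default_config.py, the path the training script below points config_flags at:

# Minimal sketch, not part of the commit: load the default config and
# override individual hyperparameters before training. Assumes the new
# file above is importable as the module default_config.
import default_config

config = default_config.get_config()
config.learning_rate = 1e-4   # override a single field
config.game = 'Breakout'      # pick a different Atari game
print(config)                 # ConfigDict prints all fields and their values
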
@@ -1,98 +1,35 @@
+import os
 from absl import flags
 from absl import app
 import jax
 import jax.random
+from ml_collections import config_flags

 import ppo_lib
 import models
 import env_utils

 FLAGS = flags.FLAGS

-# Default hyperparameters originate from PPO paper and openAI baselines 2.
-# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py
-
-flags.DEFINE_float(
-    'learning_rate', default=2.5e-4,
-    help=('The learning rate for the Adam optimizer.')
-)
-
-flags.DEFINE_integer(
-    'batch_size', default=256,
-    help=('Batch size for training.')
-)
-
-flags.DEFINE_integer(
-    'num_agents', default=8,
-    help=('Number of agents playing in parallel.')
-)
-
-flags.DEFINE_integer(
-    'actor_steps', default=128,
-    help=('Batch size for training.')
-)
-
-flags.DEFINE_integer(
-    'num_epochs', default=3,
-    help=('Number of epochs per each unroll of the policy.')
-)
-
-flags.DEFINE_float(
-    'gamma', default=0.99,
-    help=('Discount parameter.')
-)
-
-flags.DEFINE_float(
-    'lambda_', default=0.95,
-    help=('Generalized Advantage Estimation parameter.')
-)
-
-flags.DEFINE_float(
-    'clip_param', default=0.1,
-    help=('The PPO clipping parameter used to clamp ratios in loss function.')
-)
-
-flags.DEFINE_float(
-    'vf_coeff', default=0.5,
-    help=('Weighs value function loss in the total loss.')
-)
-
-flags.DEFINE_float(
-    'entropy_coeff', default=0.01,
-    help=('Weighs entropy bonus in the total loss.')
-)
-
-flags.DEFINE_boolean(
-    'decaying_lr_and_clip_param', default=True,
-    help=(('Linearly decay learning rate and clipping parameter to zero during '
-           'the training.'))
-)
-
-flags.DEFINE_string(
-    'game', default='Pong',
-    help=('The Atari game used.')
-)
-
 flags.DEFINE_string(
     'logdir', default='/tmp/ppo_training',
-    help=('Directory to save checkpoints and logging info.')
-)
+    help=('Directory to save checkpoints and logging info.'))

-flags.DEFINE_integer(
-    'total_frames', default=40000000,
-    help=('Length of training (total number of frames to be seen).')
-)
+config_flags.DEFINE_config_file(
+    'config', os.path.join(os.path.dirname(__file__), 'default_config.py'),
+    'File path to the default configuration file.')


 def main(argv):
-  game = FLAGS.game + 'NoFrameskip-v4'
+  config = FLAGS.config
+  game = config.game + 'NoFrameskip-v4'
   num_actions = env_utils.get_num_actions(game)
   print(f'Playing {game} with {num_actions} actions')
   key = jax.random.PRNGKey(0)
   key, subkey = jax.random.split(key)
   model = models.create_model(subkey, num_outputs=num_actions)
-  optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate)
+  optimizer = models.create_optimizer(model, learning_rate=config.learning_rate)
   del model
-  optimizer = ppo_lib.train(optimizer, FLAGS)
+  optimizer = ppo_lib.train(optimizer, config, FLAGS.logdir)


 if __name__ == '__main__':
   app.run(main)
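
Usage note: because the hyperparameters now live in a single ConfigDict registered via config_flags.DEFINE_config_file, individual fields can be overridden directly on the command line when launching training, for example --config.learning_rate=1e-4 or --config.game=Breakout, instead of defining one absl flag per hyperparameter as before.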