From 315902bdfb90f863158f3b2a0714d4298ec26e40 Mon Sep 17 00:00:00 2001
From: Wojciech Rzadkowski
Date: Thu, 1 Oct 2020 08:14:55 +0000
Subject: [PATCH 1/4] Refactor a long statement

---
 examples/ppo/ppo_lib.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py
index 1e53303b68..1762465507 100644
--- a/examples/ppo/ppo_lib.py
+++ b/examples/ppo/ppo_lib.py
@@ -193,10 +193,9 @@ def process_experience(
   returns = advantages + values[:-1, :]
   # After preprocessing, concatenate data from all agents.
   trajectories = (states, actions, log_probs, returns, advantages)
+  trajectory_len = num_agents * actor_steps
   trajectories = tuple(map(
-    lambda x: onp.reshape(
-      x, (num_agents * actor_steps,) + x.shape[2:]),
-    trajectories))
+    lambda x: onp.reshape(x, (trajectory_len,) + x.shape[2:]), trajectories))
   return trajectories
 
 def train(

From d2eae5c03983dfa756f0ecf6c9f6e43f22b9d059 Mon Sep 17 00:00:00 2001
From: Wojciech Rzadkowski
Date: Thu, 1 Oct 2020 08:30:50 +0000
Subject: [PATCH 2/4] Test: use assertEqual and clip rewards when testing them

---
 examples/ppo/ppo_lib_test.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py
index 65d649e73a..713030fb39 100644
--- a/examples/ppo/ppo_lib_test.py
+++ b/examples/ppo/ppo_lib_test.py
@@ -50,17 +50,17 @@ def test_creation(self):
     game = self.choose_random_game()
     env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
-    self.assertTrue(obs.shape == frame_shape)
+    self.assertEqual(obs.shape, frame_shape)
 
   def test_step(self):
     frame_shape = (84, 84, 4)
     game = self.choose_random_game()
-    env = env_utils.create_env(game, clip_rewards=False)
+    env = env_utils.create_env(game, clip_rewards=True)
     obs = env.reset()
     actions = [1, 2, 3, 0]
     for a in actions:
       obs, reward, done, info = env.step(a)
-      self.assertTrue(obs.shape == frame_shape)
+      self.assertEqual(obs.shape, frame_shape)
       self.assertTrue(reward <= 1. and reward >= -1.)
       self.assertTrue(isinstance(done, bool))
       self.assertTrue(isinstance(info, dict))
@@ -81,9 +81,9 @@ def test_model(self):
     test_batch_size, obs_shape = 10, (84, 84, 4)
     random_input = onp.random.random(size=(test_batch_size,) + obs_shape)
     log_probs, values = optimizer.target(random_input)
-    self.assertTrue(values.shape == (test_batch_size, 1))
+    self.assertEqual(values.shape, (test_batch_size, 1))
     sum_probs = onp.sum(onp.exp(log_probs), axis=1)
-    self.assertTrue(sum_probs.shape == (test_batch_size, ))
+    self.assertEqual(sum_probs.shape, (test_batch_size, ))
     onp_testing.assert_allclose(sum_probs, onp.ones((test_batch_size, )),
                                 atol=1e-6)
 

From d444075ee4337ceb2cc21ce071622ba0ed62d745 Mon Sep 17 00:00:00 2001
From: Wojciech Rzadkowski
Date: Thu, 1 Oct 2020 08:31:56 +0000
Subject: [PATCH 3/4] Compile vectorized code instead of vectorizing compiled
 code

---
 examples/ppo/ppo_lib.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py
index 1762465507..a8c495553a 100644
--- a/examples/ppo/ppo_lib.py
+++ b/examples/ppo/ppo_lib.py
@@ -15,8 +15,8 @@
 import agent
 import test_episodes
 
-@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 @jax.jit
+@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1)
 def gae_advantages(
   rewards: onp.ndarray,
   terminal_masks: onp.ndarray,

From f3a9d03e30a0cf17078c048f6f90be218a738b51 Mon Sep 17 00:00:00 2001
From: Wojciech Rzadkowski
Date: Thu, 1 Oct 2020 08:32:41 +0000
Subject: [PATCH 4/4] Specify static_argnums with proper int

---
 examples/ppo/ppo_lib.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py
index a8c495553a..f662321e36 100644
--- a/examples/ppo/ppo_lib.py
+++ b/examples/ppo/ppo_lib.py
@@ -55,7 +55,7 @@ def gae_advantages(
   advantages = advantages[::-1]
   return jnp.array(advantages)
 
-@functools.partial(jax.jit, static_argnums=(6))
+@functools.partial(jax.jit, static_argnums=6)
 def train_step(
   optimizer: flax.optim.base.Optimizer,
   trajectories: Tuple,
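
Notes on the series (illustrative sketches only, not part of the patches).

Patch 1 collapses the per-step and per-agent axes into one batch axis before
training. A minimal standalone sketch of the same reshape; the shapes below
are made up for illustration and are not taken from ppo_lib.py:

    import numpy as onp

    actor_steps, num_agents = 5, 3
    # e.g. stacked observations of shape (84, 84, 4) per step, per agent
    states = onp.zeros((actor_steps, num_agents, 84, 84, 4))
    trajectory_len = num_agents * actor_steps
    # Collapse the first two axes, keep the trailing feature axes intact.
    flat = onp.reshape(states, (trajectory_len,) + states.shape[2:])
    assert flat.shape == (15, 84, 84, 4)

Patch 2 also flips clip_rewards to True in test_step, which is what makes the
subsequent "reward <= 1. and reward >= -1." assertion meaningful.

Patch 3 reorders the decorators so that the vectorized function is compiled,
rather than an already compiled function being vectorized. Decorators apply
bottom-up, so placing @jax.jit above the jax.vmap partial yields jit(vmap(f)),
one XLA compilation of the batched computation. A self-contained sketch with
a hypothetical function f (the real gae_advantages maps its first three
arguments over axis 1 and broadcasts the last two, per its in_axes):

    import functools
    import jax
    import jax.numpy as jnp

    # Bottom-up: f is first vmapped over axis 1, then the vectorized
    # function as a whole is jit-compiled.
    @jax.jit
    @functools.partial(jax.vmap, in_axes=1, out_axes=1)
    def f(column):
      # Operates on a single column; vmap maps it over axis 1.
      return jnp.cumsum(column)

    x = jnp.ones((4, 3))   # axis 1 indexes the mapped dimension
    print(f(x).shape)      # (4, 3)

Patch 4 is a clarity fix: in Python, (6) is simply the int 6 (a one-element
tuple would be (6,)), so static_argnums=(6) and static_argnums=6 behave
identically; the parentheses only looked like a tuple. jax.jit accepts either
an int or a sequence of ints for static_argnums.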