From d38a671096114d0b8059026a75d6978e07e53d7a Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 11 Sep 2020 10:43:58 +0000 Subject: [PATCH 01/48] Initial PPO commit --- examples/ppo/agent.py | 11 + examples/ppo/env.py | 36 +++ examples/ppo/main.py | 234 ++++++++++++++++++++ examples/ppo/models.py | 50 +++++ examples/ppo/remote.py | 46 ++++ examples/ppo/seed_rl_atari_preprocessing.py | 212 ++++++++++++++++++ examples/ppo/test_episodes.py | 38 ++++ 7 files changed, 627 insertions(+) create mode 100644 examples/ppo/agent.py create mode 100644 examples/ppo/env.py create mode 100644 examples/ppo/main.py create mode 100644 examples/ppo/models.py create mode 100644 examples/ppo/remote.py create mode 100644 examples/ppo/seed_rl_atari_preprocessing.py create mode 100644 examples/ppo/test_episodes.py diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py new file mode 100644 index 0000000000..300e639bfe --- /dev/null +++ b/examples/ppo/agent.py @@ -0,0 +1,11 @@ +import jax +import numpy as onp + +@jax.jit +def policy_action(model, state): + """Forward pass of the network. + Potentially the random choice of the action from probabilities can be moved + here with additional rng_key parameter.""" + # print("Inference: compile") + out = model(state) + return out \ No newline at end of file diff --git a/examples/ppo/env.py b/examples/ppo/env.py new file mode 100644 index 0000000000..05be2e2aeb --- /dev/null +++ b/examples/ppo/env.py @@ -0,0 +1,36 @@ +import collections +import gym +import numpy as onp + +from seed_rl_atari_preprocessing import AtariPreprocessing + +class FrameStack: + ''' + Class that wraps an AtariPreprocessing object and implements + stacking of `num_frames` last frames of the game + ''' + def __init__(self, preproc: AtariPreprocessing, num_frames : int): + self.preproc = preproc + self.num_frames = num_frames + self.frames = collections.deque(maxlen=num_frames) + + def reset(self): + ob = self.preproc.reset() + for _ in range(self.num_frames): + self.frames.append(ob) + return self._get_array() + + def step(self, action: int): + ob, reward, done, info = self.preproc.step(action) + self.frames.append(ob) + return self._get_array(), reward, done, info + + def _get_array(self): + assert len(self.frames) == self.num_frames + return onp.concatenate(self.frames, axis=-1) + +def create_env(): + env = gym.make("PongNoFrameskip-v4") + preproc = AtariPreprocessing(env) + stack = FrameStack(preproc, num_frames=4) + return stack \ No newline at end of file diff --git a/examples/ppo/main.py b/examples/ppo/main.py new file mode 100644 index 0000000000..b97cad31f5 --- /dev/null +++ b/examples/ppo/main.py @@ -0,0 +1,234 @@ +import jax +import jax.random +import jax.numpy as jnp +import numpy as onp +import flax +import time +from typing import Tuple, List +from queue import Queue +import threading + + +from models import create_model, create_optimizer +from agent import policy_action +from remote import RemoteSimulator +from test_episodes import test + +# @jax.jit +def gae_advantages(rewards, terminal_masks, values, discount, gae_param): + """Use Generalized Advantage Estimation (GAE) to compute advantages + Eqs. (11-12) in PPO paper arXiv: 1707.06347""" + assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " + "Eq. 
(12) in PPO paper requires V(s_{t+1}) to calculate \delta_t") + return_values, gae = [], 0 + for t in reversed(range(len(rewards))): + #masks to set next state value to 0 for terminal states + value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] + delta = rewards[t] + value_diff + # masks[t] to ensure that values before and after a terminal state + # are independent of each other + gae = delta + discount * gae_param * terminal_masks[t] * gae + return_values.insert(0, gae + values[t]) + return onp.array(return_values) #jnp after vectorization + +# @jax.jit +def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, + batch_size): + def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): + states, actions, old_log_probs, returns, advantages = minibatch + probs, values = model(states) + log_probs = jnp.log(probs) + entropy = jnp.sum(-probs*log_probs, axis=1).mean() + # from all probs from the forward pass, we need to choose ones + # corresponding to actually taken actions + # log_probs_act_taken = log_probs[jnp.arange(probs.shape[0]), actions]) + # above hits "Indexing mode not yet supported." + log_probs_act_taken = jnp.log(jnp.array( + [probs[i, actions[i]]for i in range(actions.shape[0])])) + ratios = jnp.exp(log_probs_act_taken - old_log_probs) + PG_loss = ratios * advantages + clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, + 1. + clip_param) + PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss)) + value_loss = jnp.mean(jnp.square(returns - values)) + return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy + + iterations = trn_data[0].shape[0] // batch_size + trn_data = jax.tree_map( + lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trn_data) + loss = 0. + for batch in zip(*trn_data): + grad_fn = jax.value_and_grad(loss_fn) + l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, + entropy_coeff) + loss += l + optimizer = optimizer.apply_gradient(grad) + return optimizer, loss + + +def thread_inference( + q1 : Queue, + q2: Queue, + simulators : List[RemoteSimulator], + steps_per_actor : int): + """Worker function for a separate thread used for inference and running + the simulators in order to maximize the GPU/TPU usage. 
Runs + `steps_per_actor` time steps of the game for each of the `simulators`.""" + + while(True): + optimizer, step = q1.get() + all_experience = [] + for _ in range(steps_per_actor + 1): # +1 to get one more value + # needed for GAE + states = [] + for sim in simulators: + state = sim.conn.recv() + states.append(state) + states = onp.concatenate(states, axis=0) + + # perform inference + # policy_optimizer, step = q1.get() + # print(f"states type {type(states)}") + probs, values = policy_action(optimizer.target, states) + + probs = onp.array(probs) + # print("probs after onp conversion", probs) + + for i, sim in enumerate(simulators): + # probs[i] should sum up to 1, but there are float round errors + # if using jnp.array directly, it required division by probs[i].sum() + # better solutions can be thought of + # issue might be a result of the network using jnp.int 32, , not 64 + probabilities = probs[i] # / probs[i].sum() + action = onp.random.choice(probs.shape[1], p=probabilities) + #in principle, one could avoid sending value and log prob back and forth + sim.conn.send((action, values[i], onp.log(probs[i][action]))) + + # get experience from simulators + experiences = [] + for sim in simulators: + sample = sim.conn.recv() + experiences.append(sample) + all_experience.append(experiences) + + q2.put(all_experience) + + +def train( + optimizer : flax.optim.base.Optimizer, + # target_model : nn.base.Model, + steps_total : int, # maybe rename to frames_total + num_agents : int, + train_device, + inference_device): + + simulators = [RemoteSimulator() for i in range(num_agents)] + q1, q2 = Queue(maxsize=1), Queue(maxsize=1) + inference_thread = threading.Thread(target=thread_inference, + args=(q1, q2, simulators, STEPS_PER_ACTOR), daemon=True) + inference_thread.start() + t1 = time.time() + + for s in range(steps_total // num_agents): + print(f"training loop step {s}") + #bookkeeping and testing + if (s + 1) % (10000 // (num_agents*STEPS_PER_ACTOR)) == 0: + print(f"Frames processed {s*num_agents*STEPS_PER_ACTOR}" + + f"time elapsed {time.time()-t1}") + t1 = time.time() + if (s + 1) % (50000 // (num_agents*STEPS_PER_ACTOR)) == 0: + test(1, optimizer.target, render=False) + + + # send the up-to-date policy model and current step to inference thread + step = s*num_agents + q1.put((optimizer, step)) + + # perform PPO training + # experience is a list of list of tuples, here we preprocess this data to + # get required input for GAE and then for training + # initial version, needs improvement in terms of speed & readability + if s > 0: #avoid training when there's no data yet + obs_shape = (84, 84, 4) + states = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS) + obs_shape, + dtype=onp.float32) + actions = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.int32) + rewards = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) + values = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) + log_probs = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), + dtype=onp.float32) + dones = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) + + # experiences state, action, reward, value, log_prob, done) + for time_step, exp in enumerate(all_experiences): + for agent_id, exp_agent in enumerate(exp): + states[time_step, agent_id, ...] 
= exp_agent[0] + actions[time_step, agent_id] = exp_agent[1] + rewards[time_step, agent_id] =exp_agent[2] + values[time_step, agent_id] = exp_agent[3] + log_probs[time_step, agent_id] = exp_agent[4] + # dones need to be 0 for terminal states + dones[time_step, agent_id] = float(not exp_agent[5]) + + #calculate returns using GAE (needs to be vectorized instead of foor loop) + returns = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS)) + for i in range(NUM_AGENTS): + returns[:, i] = gae_advantages(rewards[:-1, i], dones[:-1, i], + values[:, i], DISCOUNT, GAE_PARAM) + advantages = returns - values[:-1, :] + + #getting rid of unnecessary data (one more value was needed for GAE) + states = states[:-1, ...].copy() + actions = actions[:-1, ...].copy() + log_probs = log_probs[:-1, ...].copy() + # after all the preprocessing, we discard the information + # about from which agent the data comes by reshaping + trn_data = (states, actions, log_probs, returns, advantages) + trn_data = tuple(map( + lambda x: onp.reshape(x, + (NUM_AGENTS * STEPS_PER_ACTOR , ) + x.shape[2:]), trn_data) + ) + for _ in range(NUM_EPOCHS): #possibly compile this loop inside a jit + permutation = onp.random.permutation(NUM_AGENTS * STEPS_PER_ACTOR) + trn_data = tuple(map(lambda x: x[permutation], trn_data)) + optimizer, _ = train_step(optimizer, trn_data, CLIP_PARAM, VF_COEFF, + ENTROPY_COEFF, BATCH_SIZE) + #end of PPO training + + #collect new data from the inference thread + all_experiences = q2.get() + + return None + + +STEPS_PER_ACTOR = 128 +NUM_AGENTS = 8 +NUM_EPOCHS = 3 +BATCH_SIZE = 32 * 8 + +DISCOUNT = 0.99 #usually denoted with \gamma +GAE_PARAM = 0.95 #usually denoted with \lambda + +VF_COEFF = 1 #weighs value function loss in total loss +ENTROPY_COEFF = 0.01 # weighs entropy bonus in the total loss + +LR = 2.5e-4 + +CLIP_PARAM = 0.1 + +key = jax.random.PRNGKey(0) +key, subkey = jax.random.split(key) +model = create_model(subkey) +optimizer = create_optimizer(model, learning_rate=LR) +del model + +def main(): + num_agents = NUM_AGENTS + total_frames = 4000000 + train_device = jax.devices()[0] + inference_device = jax.devices()[1] + jax.device_put(optimizer.target, device=train_device) + train(optimizer, total_frames, num_agents, train_device, inference_device) + +if __name__ == '__main__': + main() diff --git a/examples/ppo/models.py b/examples/ppo/models.py new file mode 100644 index 0000000000..aea8985b17 --- /dev/null +++ b/examples/ppo/models.py @@ -0,0 +1,50 @@ +import flax +from flax import nn +import jax.numpy as jnp + +class ActorCritic(flax.nn.Module): + ''' + Architecture from "Human-level control through deep reinforcement learning." + Nature 518, no. 7540 (2015): 529-533. + Note that this is different than the one from "Playing atari with deep + reinforcement learning." arxiv.org/abs/1312.5602 (2013) + ''' + def apply(self, x): + x = x.astype(jnp.float32) / 255. 
+ dtype = jnp.float32 + x = nn.Conv(x, features=32, kernel_size=(8, 8), + strides=(4, 4), name='conv1', + dtype=dtype) + # x = nn.relu(x) + x = jnp.maximum(0, x) + x = nn.Conv(x, features=64, kernel_size=(4, 4), + strides=(2, 2), name='conv2', + dtype=dtype) + # x = nn.relu(x) + x = jnp.maximum(0, x) + x = nn.Conv(x, features=64, kernel_size=(3, 3), + strides=(1, 1), name='conv3', + dtype=dtype) + # x = nn.relu(x) + x = jnp.maximum(0, x) + x = x.reshape((x.shape[0], -1)) # flatten + x = nn.Dense(x, features=512, name='hidden', dtype=dtype) + # x = nn.relu(x) + x = jnp.maximum(0, x) + # network used to both estimate policy (logits) and expected state value + # see github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py + logits = nn.Dense(x, features=4, name='logits', dtype=dtype) + policy_probabilities = nn.softmax(logits) + value = nn.Dense(x, features=1, name='value', dtype=dtype) + return policy_probabilities, value + +def create_model(key): + input_dims = (1, 84, 84, 4) #(minibatch, height, width, stacked frames) + _, initial_par = ActorCritic.init_by_shape(key, [(input_dims, jnp.float32)]) + model = flax.nn.Model(ActorCritic, initial_par) + return model + +def create_optimizer(model, learning_rate): + optimizer_def = flax.optim.Adam(learning_rate) + optimizer = optimizer_def.create(model) + return optimizer diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py new file mode 100644 index 0000000000..dfbf1ca52c --- /dev/null +++ b/examples/ppo/remote.py @@ -0,0 +1,46 @@ +import multiprocessing +import numpy as onp +from env import create_env + + +class RemoteSimulator: + """ + Class that wraps basic functionality needed for an agent + emulating Atari in a separate process. + An object of this class is created for every agent. + """ + def __init__(self): + parent_conn, child_conn = multiprocessing.Pipe() + self.proc = multiprocessing.Process( + target=rcv_action_send_exp, args=(child_conn,)) + self.conn = parent_conn + self.proc.start() + + +def rcv_action_send_exp(conn): + """ + Function running on remote agents. Receives action from + the main learner, performs one step of simulation and + sends back collected experience. + """ + env = create_env() + while True: + obs = env.reset() + done = False + state = get_state(obs) + while not done: + conn.send(state) + action, value, log_prob = conn.recv() + obs, reward, done, _ = env.step(action) + next_state = get_state(obs) if not done else None + # maybe a dictionary instead of a tuple would be better? + experience = (state, action, reward, value, log_prob, done) + conn.send(experience) + if done: + break + state = next_state + + +def get_state(observation): + state = onp.array(observation) + return state[None, ...] diff --git a/examples/ppo/seed_rl_atari_preprocessing.py b/examples/ppo/seed_rl_atari_preprocessing.py new file mode 100644 index 0000000000..2e5982d760 --- /dev/null +++ b/examples/ppo/seed_rl_atari_preprocessing.py @@ -0,0 +1,212 @@ +# coding=utf-8 +# Copyright 2019 The SEED Authors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""A class implementing minimal Atari 2600 preprocessing. +Adapted from Dopamine. +""" + +from gym.spaces.box import Box +import numpy as np + +import cv2 + + +class AtariPreprocessing(object): + """A class implementing image preprocessing for Atari 2600 agents. + Specifically, this provides the following subset from the JAIR paper + (Bellemare et al., 2013) and Nature DQN paper (Mnih et al., 2015): + * Frame skipping (defaults to 4). + * Terminal signal when a life is lost (off by default). + * Grayscale and max-pooling of the last two frames. + * Downsample the screen to a square image (defaults to 84x84). + More generally, this class follows the preprocessing guidelines set down in + Machado et al. (2018), "Revisiting the Arcade Learning Environment: + Evaluation Protocols and Open Problems for General Agents". + It also provides random starting no-ops, which are used in the Rainbow, Apex + and R2D2 papers. + """ + + def __init__(self, environment, frame_skip=4, terminal_on_life_loss=False, + screen_size=84, max_random_noops=0): + """Constructor for an Atari 2600 preprocessor. + Args: + environment: Gym environment whose observations are preprocessed. + frame_skip: int, the frequency at which the agent experiences the game. + terminal_on_life_loss: bool, If True, the step() method returns + is_terminal=True whenever a life is lost. See Mnih et al. 2015. + screen_size: int, size of a resized Atari 2600 frame. + max_random_noops: int, maximum number of no-ops to apply at the beginning + of each episode to reduce determinism. These no-ops are applied at a + low-level, before frame skipping. + Raises: + ValueError: if frame_skip or screen_size are not strictly positive. + """ + if frame_skip <= 0: + raise ValueError('Frame skip should be strictly positive, got {}'. + format(frame_skip)) + if screen_size <= 0: + raise ValueError('Target screen size should be strictly positive, got {}'. + format(screen_size)) + + self.environment = environment + self.terminal_on_life_loss = terminal_on_life_loss + self.frame_skip = frame_skip + self.screen_size = screen_size + self.max_random_noops = max_random_noops + + obs_dims = self.environment.observation_space + # Stores temporary observations used for pooling over two successive + # frames. + self.screen_buffer = [ + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), + np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) + ] + + self.game_over = False + self.lives = 0 # Will need to be set by reset(). + + @property + def observation_space(self): + # Return the observation space adjusted to match the shape of the processed + # observations. + return Box(low=0, high=255, shape=(self.screen_size, self.screen_size, 1), + dtype=np.uint8) + + @property + def action_space(self): + return self.environment.action_space + + @property + def reward_range(self): + return self.environment.reward_range + + @property + def metadata(self): + return self.environment.metadata + + def close(self): + return self.environment.close() + + def apply_random_noops(self): + """Steps self.environment with random no-ops.""" + if self.max_random_noops <= 0: + return + # Other no-ops implementations actually always do at least 1 no-op. We + # follow them. 
+ no_ops = self.environment.np_random.randint(1, self.max_random_noops + 1) + for _ in range(no_ops): + _, _, game_over, _ = self.environment.step(0) + if game_over: + self.environment.reset() + + def reset(self): + """Resets the environment. + Returns: + observation: numpy array, the initial observation emitted by the + environment. + """ + self.environment.reset() + self.apply_random_noops() + + self.lives = self.environment.ale.lives() + self._fetch_grayscale_observation(self.screen_buffer[0]) + self.screen_buffer[1].fill(0) + return self._pool_and_resize() + + def render(self, mode): + """Renders the current screen, before preprocessing. + This calls the Gym API's render() method. + Args: + mode: Mode argument for the environment's render() method. + Valid values (str) are: + 'rgb_array': returns the raw ALE image. + 'human': renders to display via the Gym renderer. + Returns: + if mode='rgb_array': numpy array, the most recent screen. + if mode='human': bool, whether the rendering was successful. + """ + return self.environment.render(mode) + + def step(self, action): + """Applies the given action in the environment. + Remarks: + * If a terminal state (from life loss or episode end) is reached, this may + execute fewer than self.frame_skip steps in the environment. + * Furthermore, in this case the returned observation may not contain valid + image data and should be ignored. + Args: + action: The action to be executed. + Returns: + observation: numpy array, the observation following the action. + reward: float, the reward following the action. + is_terminal: bool, whether the environment has reached a terminal state. + This is true when a life is lost and terminal_on_life_loss, or when the + episode is over. + info: Gym API's info data structure. + """ + accumulated_reward = 0. + + for time_step in range(self.frame_skip): + # We bypass the Gym observation altogether and directly fetch the + # grayscale image from the ALE. This is a little faster. + _, reward, game_over, info = self.environment.step(action) + accumulated_reward += reward + + if self.terminal_on_life_loss: + new_lives = self.environment.ale.lives() + is_terminal = game_over or new_lives < self.lives + self.lives = new_lives + else: + is_terminal = game_over + + if is_terminal: + break + # We max-pool over the last two frames, in grayscale. + elif time_step >= self.frame_skip - 2: + t = time_step - (self.frame_skip - 2) + self._fetch_grayscale_observation(self.screen_buffer[t]) + + # Pool the last two observations. + observation = self._pool_and_resize() + + self.game_over = game_over + return observation, accumulated_reward, is_terminal, info + + def _fetch_grayscale_observation(self, output): + """Returns the current observation in grayscale. + The returned observation is stored in 'output'. + Args: + output: numpy array, screen buffer to hold the returned observation. + Returns: + observation: numpy array, the current observation in grayscale. + """ + self.environment.ale.getScreenGrayscale(output) + return output + + def _pool_and_resize(self): + """Transforms two frames into a Nature DQN observation. + For efficiency, the transformation is done in-place in self.screen_buffer. + Returns: + transformed_screen: numpy array, pooled, resized screen. + """ + # Pool if there are enough screens to do so. 
+ if self.frame_skip > 1: + np.maximum(self.screen_buffer[0], self.screen_buffer[1], + out=self.screen_buffer[0]) + + transformed_image = cv2.resize(self.screen_buffer[0], + (self.screen_size, self.screen_size), + interpolation=cv2.INTER_LINEAR) + int_image = np.asarray(transformed_image, dtype=np.uint8) + return np.expand_dims(int_image, axis=2) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py new file mode 100644 index 0000000000..064acf5913 --- /dev/null +++ b/examples/ppo/test_episodes.py @@ -0,0 +1,38 @@ +import time +import itertools +import gym +import flax +import numpy as onp + +from env import create_env +from remote import get_state +from agent import policy_action + +def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): + test_env = create_env() + if render: + test_env = gym.wrappers.Monitor( + test_env, "./rendered/" + "ddqn_pong_recording", force=True) + for e in range(n_episodes): + obs = test_env.reset() + state = get_state(obs) + total_reward = 0.0 + for t in itertools.count(): + probs, _ = policy_action(model, state) + probs = onp.array(probs, dtype=onp.float64) + probabilities = probs[0] / probs[0].sum() + action = onp.random.choice(probs.shape[1], p=probabilities) + obs, reward, done, _ = test_env.step(action) + total_reward += reward + if render: + test_env.render() + time.sleep(0.01) + if not done: + next_state = get_state(obs) + else: + next_state = None + state = next_state + if done: + print(f"Finished Episode {e} with reward {total_reward}") + break + del test_env From c0ff3efd4d095de76fdebb12e41035bdd318ce05 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 11 Sep 2020 13:02:33 +0000 Subject: [PATCH 02/48] Use jax.nn.one_hot instead of list comprehension for speed --- examples/ppo/main.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index b97cad31f5..e82f06955f 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -39,12 +39,11 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): probs, values = model(states) log_probs = jnp.log(probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() - # from all probs from the forward pass, we need to choose ones - # corresponding to actually taken actions + # we need to choose probs corresponding to actually taken actions # log_probs_act_taken = log_probs[jnp.arange(probs.shape[0]), actions]) - # above hits "Indexing mode not yet supported." - log_probs_act_taken = jnp.log(jnp.array( - [probs[i, actions[i]]for i in range(actions.shape[0])])) + # above hits "Indexing mode not yet supported.", hence one hot solution + act_one_hot = jax.nn.one_hot(actions, num_classes=probs.shape[1]) + log_probs_act_taken = jnp.log(jnp.sum(act_one_hot*probs, axis=1)) ratios = jnp.exp(log_probs_act_taken - old_log_probs) PG_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, From f576a76abab06a1358bb10ed7bb10e17dbded37e Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 11 Sep 2020 13:30:04 +0000 Subject: [PATCH 03/48] Clarity: calculate only advantages in gae_advantages() --- examples/ppo/main.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index e82f06955f..e1ec57638d 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -20,16 +20,17 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): Eqs. 
(11-12) in PPO paper arXiv: 1707.06347""" assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate \delta_t") - return_values, gae = [], 0 + advantages, gae = [], 0 + # Key observation: A_{t} = \delta_t + \gamma*\lambda*A_{t+1} for t in reversed(range(len(rewards))): - #masks to set next state value to 0 for terminal states + # masks to set next state value to 0 for terminal states value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] delta = rewards[t] + value_diff # masks[t] to ensure that values before and after a terminal state # are independent of each other gae = delta + discount * gae_param * terminal_masks[t] * gae - return_values.insert(0, gae + values[t]) - return onp.array(return_values) #jnp after vectorization + advantages.insert(0, gae) + return onp.array(advantages) #jnp after vectorization # @jax.jit def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, @@ -169,13 +170,12 @@ def train( # dones need to be 0 for terminal states dones[time_step, agent_id] = float(not exp_agent[5]) - #calculate returns using GAE (needs to be vectorized instead of foor loop) - returns = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS)) + #calculate advantages w. GAE (needs to be vectorized instead of foor loop) + advantages = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS)) for i in range(NUM_AGENTS): - returns[:, i] = gae_advantages(rewards[:-1, i], dones[:-1, i], + advantages[:, i] = gae_advantages(rewards[:-1, i], dones[:-1, i], values[:, i], DISCOUNT, GAE_PARAM) - advantages = returns - values[:-1, :] - + returns = advantages + values[:-1, :] #getting rid of unnecessary data (one more value was needed for GAE) states = states[:-1, ...].copy() actions = actions[:-1, ...].copy() From 11bc5938fe62f36607644e24afbca4566100b6f6 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 11 Sep 2020 14:56:32 +0000 Subject: [PATCH 04/48] jit-compile training step --- examples/ppo/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index e1ec57638d..2cf58aeaf4 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -32,9 +32,8 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): advantages.insert(0, gae) return onp.array(advantages) #jnp after vectorization -# @jax.jit -def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, - batch_size): +@jax.jit +def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch probs, values = model(states) @@ -53,6 +52,7 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): value_loss = jnp.mean(jnp.square(returns - values)) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy + batch_size = BATCH_SIZE iterations = trn_data[0].shape[0] // batch_size trn_data = jax.tree_map( lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trn_data) @@ -191,7 +191,7 @@ def train( permutation = onp.random.permutation(NUM_AGENTS * STEPS_PER_ACTOR) trn_data = tuple(map(lambda x: x[permutation], trn_data)) optimizer, _ = train_step(optimizer, trn_data, CLIP_PARAM, VF_COEFF, - ENTROPY_COEFF, BATCH_SIZE) + ENTROPY_COEFF) #end of PPO training #collect new data from the inference thread From 5feeec71666b4b76993e8e0a55787336dd5fb989 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 14 Sep 
2020 10:20:40 +0000 Subject: [PATCH 05/48] Clarity: get rid of most [:-1] indexing --- examples/ppo/main.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 2cf58aeaf4..85323224dd 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -29,7 +29,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): # masks[t] to ensure that values before and after a terminal state # are independent of each other gae = delta + discount * gae_param * terminal_masks[t] * gae - advantages.insert(0, gae) + advantages = [gae] + advantages return onp.array(advantages) #jnp after vectorization @jax.jit @@ -150,36 +150,38 @@ def train( # initial version, needs improvement in terms of speed & readability if s > 0: #avoid training when there's no data yet obs_shape = (84, 84, 4) - states = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS) + obs_shape, + states = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS) + obs_shape, dtype=onp.float32) - actions = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.int32) - rewards = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) + actions = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.int32) + rewards = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) values = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) - log_probs = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), + log_probs = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) - dones = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) + dones = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) # experiences state, action, reward, value, log_prob, done) - for time_step, exp in enumerate(all_experiences): - for agent_id, exp_agent in enumerate(exp): - states[time_step, agent_id, ...] = exp_agent[0] - actions[time_step, agent_id] = exp_agent[1] - rewards[time_step, agent_id] =exp_agent[2] - values[time_step, agent_id] = exp_agent[3] - log_probs[time_step, agent_id] = exp_agent[4] + # for time_step, exp in enumerate(all_experiences): + for t in range(len(all_experiences) - 1): #last only for next_values + for agent_id, exp_agent in enumerate(all_experiences[t]): + states[t, agent_id, ...] = exp_agent[0] + actions[t, agent_id] = exp_agent[1] + rewards[t, agent_id] =exp_agent[2] + values[t, agent_id] = exp_agent[3] + log_probs[t, agent_id] = exp_agent[4] # dones need to be 0 for terminal states - dones[time_step, agent_id] = float(not exp_agent[5]) - + dones[t, agent_id] = float(not exp_agent[5]) + for a in range(num_agents): + values[-1, a] = all_experiences[-1][a][3] #calculate advantages w. 
GAE (needs to be vectorized instead of foor loop) advantages = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS)) for i in range(NUM_AGENTS): - advantages[:, i] = gae_advantages(rewards[:-1, i], dones[:-1, i], + advantages[:, i] = gae_advantages(rewards[:, i], dones[:, i], values[:, i], DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] #getting rid of unnecessary data (one more value was needed for GAE) - states = states[:-1, ...].copy() - actions = actions[:-1, ...].copy() - log_probs = log_probs[:-1, ...].copy() + # states = states[:-1, ...].copy() + # actions = actions[:-1, ...].copy() + # log_probs = log_probs[:-1, ...].copy() # after all the preprocessing, we discard the information # about from which agent the data comes by reshaping trn_data = (states, actions, log_probs, returns, advantages) From 8be6677cdfbd33431cfe13754619c38b5486f50f Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 14 Sep 2020 11:52:19 +0000 Subject: [PATCH 06/48] jit & vmap Generalized Advantage Estimation --- examples/ppo/main.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 85323224dd..4b1ed892f3 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -4,6 +4,7 @@ import numpy as onp import flax import time +from functools import partial from typing import Tuple, List from queue import Queue import threading @@ -14,7 +15,8 @@ from remote import RemoteSimulator from test_episodes import test -# @jax.jit +@partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) +@jax.jit def gae_advantages(rewards, terminal_masks, values, discount, gae_param): """Use Generalized Advantage Estimation (GAE) to compute advantages Eqs. (11-12) in PPO paper arXiv: 1707.06347""" @@ -30,7 +32,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): # are independent of each other gae = delta + discount * gae_param * terminal_masks[t] * gae advantages = [gae] + advantages - return onp.array(advantages) #jnp after vectorization + return jnp.array(advantages) @jax.jit def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): @@ -172,11 +174,8 @@ def train( dones[t, agent_id] = float(not exp_agent[5]) for a in range(num_agents): values[-1, a] = all_experiences[-1][a][3] - #calculate advantages w. GAE (needs to be vectorized instead of foor loop) - advantages = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS)) - for i in range(NUM_AGENTS): - advantages[:, i] = gae_advantages(rewards[:, i], dones[:, i], - values[:, i], DISCOUNT, GAE_PARAM) + # calculate advantages w. GAE + advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] #getting rid of unnecessary data (one more value was needed for GAE) # states = states[:-1, ...].copy() From 670978fad016c5c97b997f5b7f2f7bf9c70ebec1 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 14 Sep 2020 12:37:49 +0000 Subject: [PATCH 07/48] Add advantage normalization --- examples/ppo/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 4b1ed892f3..53bceab0e5 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -47,6 +47,8 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): act_one_hot = jax.nn.one_hot(actions, num_classes=probs.shape[1]) log_probs_act_taken = jnp.log(jnp.sum(act_one_hot*probs, axis=1)) ratios = jnp.exp(log_probs_act_taken - old_log_probs) + # adv. 
normalization (following the OpenAI baselines) + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) PG_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, 1. + clip_param) From 3414100881184126bbcd220583e7e2a17c0902bd Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 14 Sep 2020 13:46:57 +0000 Subject: [PATCH 08/48] Small code cleanup --- examples/ppo/agent.py | 4 ++-- examples/ppo/env.py | 3 +-- examples/ppo/main.py | 18 ++++++------------ examples/ppo/remote.py | 6 ++---- 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 300e639bfe..df64825a96 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -5,7 +5,7 @@ def policy_action(model, state): """Forward pass of the network. Potentially the random choice of the action from probabilities can be moved - here with additional rng_key parameter.""" - # print("Inference: compile") + here with additional rng_key parameter. + """ out = model(state) return out \ No newline at end of file diff --git a/examples/ppo/env.py b/examples/ppo/env.py index 05be2e2aeb..f3ff04a1db 100644 --- a/examples/ppo/env.py +++ b/examples/ppo/env.py @@ -5,8 +5,7 @@ from seed_rl_atari_preprocessing import AtariPreprocessing class FrameStack: - ''' - Class that wraps an AtariPreprocessing object and implements + '''Class that wraps an AtariPreprocessing object and implements stacking of `num_frames` last frames of the game ''' def __init__(self, preproc: AtariPreprocessing, num_frames : int): diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 53bceab0e5..348fcb2a95 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -19,11 +19,12 @@ @jax.jit def gae_advantages(rewards, terminal_masks, values, discount, gae_param): """Use Generalized Advantage Estimation (GAE) to compute advantages - Eqs. (11-12) in PPO paper arXiv: 1707.06347""" + Eqs. (11-12) in PPO paper arXiv: 1707.06347. + Uses key observation that A_{t} = \delta_t + \gamma*\lambda*A_{t+1}. + """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate \delta_t") advantages, gae = [], 0 - # Key observation: A_{t} = \delta_t + \gamma*\lambda*A_{t+1} for t in reversed(range(len(rewards))): # masks to set next state value to 0 for terminal states value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] @@ -77,7 +78,8 @@ def thread_inference( steps_per_actor : int): """Worker function for a separate thread used for inference and running the simulators in order to maximize the GPU/TPU usage. Runs - `steps_per_actor` time steps of the game for each of the `simulators`.""" + `steps_per_actor` time steps of the game for each of the `simulators`. + """ while(True): optimizer, step = q1.get() @@ -92,7 +94,6 @@ def thread_inference( # perform inference # policy_optimizer, step = q1.get() - # print(f"states type {type(states)}") probs, values = policy_action(optimizer.target, states) probs = onp.array(probs) @@ -163,8 +164,6 @@ def train( dtype=onp.float32) dones = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) - # experiences state, action, reward, value, log_prob, done) - # for time_step, exp in enumerate(all_experiences): for t in range(len(all_experiences) - 1): #last only for next_values for agent_id, exp_agent in enumerate(all_experiences[t]): states[t, agent_id, ...] = exp_agent[0] @@ -179,12 +178,7 @@ def train( # calculate advantages w. 
GAE advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] - #getting rid of unnecessary data (one more value was needed for GAE) - # states = states[:-1, ...].copy() - # actions = actions[:-1, ...].copy() - # log_probs = log_probs[:-1, ...].copy() - # after all the preprocessing, we discard the information - # about from which agent the data comes by reshaping + # after preprocessing, concatenate data from all agents trn_data = (states, actions, log_probs, returns, advantages) trn_data = tuple(map( lambda x: onp.reshape(x, diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index dfbf1ca52c..ddad8868a3 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -4,8 +4,7 @@ class RemoteSimulator: - """ - Class that wraps basic functionality needed for an agent + """Class that wraps basic functionality needed for an agent emulating Atari in a separate process. An object of this class is created for every agent. """ @@ -18,8 +17,7 @@ def __init__(self): def rcv_action_send_exp(conn): - """ - Function running on remote agents. Receives action from + """Function running on remote agents. Receives action from the main learner, performs one step of simulation and sends back collected experience. """ From f40b0491b0ed37a812b736143e57f3a6e0283ee8 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 14 Sep 2020 16:00:32 +0000 Subject: [PATCH 09/48] Add some asserts & debug info logging --- examples/ppo/main.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 348fcb2a95..59771ff8c5 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -39,6 +39,9 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch + shapes = list(map(lambda x : x.shape, minibatch)) + assert(shapes[0] == (BATCH_SIZE, 84, 84, 4)) + assert(all(s == (BATCH_SIZE,) for s in shapes[1:])) probs, values = model(states) log_probs = jnp.log(probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() @@ -121,8 +124,7 @@ def thread_inference( def train( optimizer : flax.optim.base.Optimizer, - # target_model : nn.base.Model, - steps_total : int, # maybe rename to frames_total + steps_total : int, num_agents : int, train_device, inference_device): @@ -138,7 +140,7 @@ def train( print(f"training loop step {s}") #bookkeeping and testing if (s + 1) % (10000 // (num_agents*STEPS_PER_ACTOR)) == 0: - print(f"Frames processed {s*num_agents*STEPS_PER_ACTOR}" + + print(f"Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + f"time elapsed {time.time()-t1}") t1 = time.time() if (s + 1) % (50000 // (num_agents*STEPS_PER_ACTOR)) == 0: @@ -164,6 +166,8 @@ def train( dtype=onp.float32) dones = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) + assert(len(all_experiences) == STEPS_PER_ACTOR + 1) + assert(len(all_experiences[0]) == NUM_AGENTS) for t in range(len(all_experiences) - 1): #last only for next_values for agent_id, exp_agent in enumerate(all_experiences[t]): states[t, agent_id, ...] 
= exp_agent[0] @@ -180,15 +184,21 @@ def train( returns = advantages + values[:-1, :] # after preprocessing, concatenate data from all agents trn_data = (states, actions, log_probs, returns, advantages) + trn_data = tuple(map( lambda x: onp.reshape(x, (NUM_AGENTS * STEPS_PER_ACTOR , ) + x.shape[2:]), trn_data) ) - for _ in range(NUM_EPOCHS): #possibly compile this loop inside a jit + print(f"Step {s}: rewards variance {rewards.var()}") + for e in range(NUM_EPOCHS): #possibly compile this loop inside a jit + shapes = list(map(lambda x : x.shape, trn_data)) + assert(shapes[0] == (NUM_AGENTS * STEPS_PER_ACTOR, 84, 84, 4)) + assert(all(s == (NUM_AGENTS * STEPS_PER_ACTOR,) for s in shapes[1:])) permutation = onp.random.permutation(NUM_AGENTS * STEPS_PER_ACTOR) trn_data = tuple(map(lambda x: x[permutation], trn_data)) - optimizer, _ = train_step(optimizer, trn_data, CLIP_PARAM, VF_COEFF, + optimizer, loss = train_step(optimizer, trn_data, CLIP_PARAM, VF_COEFF, ENTROPY_COEFF) + print(f"Step {s} epoch {e} loss {loss}") #end of PPO training #collect new data from the inference thread From 2bd52d823c9ae039dd0cc0adbabe7a47f003138e Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 15 Sep 2020 08:13:09 +0000 Subject: [PATCH 10/48] Add unit tests --- examples/ppo/unit_tests.py | 67 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/ppo/unit_tests.py diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py new file mode 100644 index 0000000000..f2408aa701 --- /dev/null +++ b/examples/ppo/unit_tests.py @@ -0,0 +1,67 @@ + +import jax +import flax +from flax import nn +import numpy as onp + +import numpy.testing as onp_testing +from absl.testing import absltest + +#test GAE +from main import gae_advantages +class TestGAE(absltest.TestCase): + def test_gae_random(self): + # create random data, simulating 4 parallel envs and 20 time_steps + envs, steps = 10, 100 + rewards = onp.random.choice([-1., 0., 1.], size=(steps, envs), + p=[0.01, 0.98, 0.01]) + terminal_masks = onp.ones(shape=(steps, envs), dtype=onp.float64) + values = onp.random.random(size=(steps + 1, envs)) + discount = 0.99 + gae_param = 0.95 + adv = gae_advantages(rewards, terminal_masks, values, discount, gae_param) + self.assertEqual(adv.shape, (steps, envs)) + # test the property A_{t} = \delta_t + \gamma*\lambda*A_{t+1} + # for each agent separately + for e in range(envs): + for t in range(steps-1): + delta = rewards[t, e] + discount * values[t+1, e] - values[t, e] + lhs = adv[t, e] + rhs = delta + discount * gae_param * adv[t+1, e] + onp_testing.assert_almost_equal(lhs, rhs) + +#test environment and preprocessing +from remote import RemoteSimulator, rcv_action_send_exp +from env import create_env +class TestEnvironmentPreprocessing(absltest.TestCase): + def test_creation(self): + frame_shape = (84, 84, 4) + env = create_env() + obs = env.reset() + self.assertTrue(obs.shape == frame_shape) + + def test_step(self): + frame_shape = (84, 84, 4) + env = create_env() + obs = env.reset() + actions = [1, 2, 3, 0] + for a in actions: + obs, reward, done, info = env.step(a) + self.assertTrue(obs.shape == frame_shape) + self.assertTrue(reward <= 1. and reward >= -1.) 
+ self.assertTrue(isinstance(done, bool)) + self.assertTrue(isinstance(info, dict)) + +#test creation of the model and optimizer +from models import create_model, create_optimizer +class TestCreation(absltest.TestCase): + def test_create(self): + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + policy_model = create_model(subkey) + policy_optimizer = create_optimizer(policy_model, learning_rate=1e-3) + self.assertTrue(isinstance(policy_model, nn.base.Model)) + self.assertTrue(isinstance(policy_optimizer, flax.optim.base.Optimizer)) + +if __name__ == '__main__': + absltest.main() \ No newline at end of file From b943afc74044dcfcb0da8af63a8c8582600e6d5b Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 15 Sep 2020 17:24:19 +0000 Subject: [PATCH 11/48] Add more debugging info --- examples/ppo/main.py | 41 ++++++++++++++++++++++++++--------- examples/ppo/test_episodes.py | 2 +- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 59771ff8c5..671776f7ed 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -24,7 +24,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate \delta_t") - advantages, gae = [], 0 + advantages, gae = [], 0. for t in reversed(range(len(rewards))): # masks to set next state value to 0 for terminal states value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] @@ -71,7 +71,8 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): entropy_coeff) loss += l optimizer = optimizer.apply_gradient(grad) - return optimizer, loss + grad_norm = sum(jnp.square(g).sum() for g in jax.tree_leaves(grad)) + return optimizer, loss, grad_norm def thread_inference( @@ -137,10 +138,10 @@ def train( t1 = time.time() for s in range(steps_total // num_agents): - print(f"training loop step {s}") + print(f"\n training loop step {s}") #bookkeeping and testing if (s + 1) % (10000 // (num_agents*STEPS_PER_ACTOR)) == 0: - print(f"Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + + print(f" Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + f"time elapsed {time.time()-t1}") t1 = time.time() if (s + 1) % (50000 // (num_agents*STEPS_PER_ACTOR)) == 0: @@ -190,15 +191,17 @@ def train( (NUM_AGENTS * STEPS_PER_ACTOR , ) + x.shape[2:]), trn_data) ) print(f"Step {s}: rewards variance {rewards.var()}") + dr = dones.ravel() + print(f"fraction of terminal states {1.-(dr.sum()/dr.shape[0])}") for e in range(NUM_EPOCHS): #possibly compile this loop inside a jit shapes = list(map(lambda x : x.shape, trn_data)) assert(shapes[0] == (NUM_AGENTS * STEPS_PER_ACTOR, 84, 84, 4)) assert(all(s == (NUM_AGENTS * STEPS_PER_ACTOR,) for s in shapes[1:])) permutation = onp.random.permutation(NUM_AGENTS * STEPS_PER_ACTOR) trn_data = tuple(map(lambda x: x[permutation], trn_data)) - optimizer, loss = train_step(optimizer, trn_data, CLIP_PARAM, VF_COEFF, - ENTROPY_COEFF) - print(f"Step {s} epoch {e} loss {loss}") + optimizer, loss, last_iter_grad_norm = train_step(optimizer, trn_data, + CLIP_PARAM, VF_COEFF, ENTROPY_COEFF) + print(f"Step {s} epoch {e} loss {loss} grad norm {last_iter_grad_norm}") #end of PPO training #collect new data from the inference thread @@ -206,7 +209,8 @@ def train( return None - +# PPO paper and openAI baselines 2 +# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py STEPS_PER_ACTOR = 128 
NUM_AGENTS = 8 NUM_EPOCHS = 3 @@ -215,13 +219,30 @@ def train( DISCOUNT = 0.99 #usually denoted with \gamma GAE_PARAM = 0.95 #usually denoted with \lambda -VF_COEFF = 1 #weighs value function loss in total loss +VF_COEFF = 0.5 #weighs value function loss in total loss ENTROPY_COEFF = 0.01 # weighs entropy bonus in the total loss LR = 2.5e-4 CLIP_PARAM = 0.1 +# openAI baselines 1 +# https://github.com/openai/baselines/blob/master/baselines/ppo1/run_atari.py +# STEPS_PER_ACTOR = 256 +# NUM_AGENTS = 8 +# NUM_EPOCHS = 4 +# BATCH_SIZE = 64 + +# DISCOUNT = 0.99 #usually denoted with \gamma +# GAE_PARAM = 0.95 #usually denoted with \lambda + +# VF_COEFF = 1. #weighs value function loss in total loss +# ENTROPY_COEFF = 0.01 # weighs entropy bonus in the total loss + +# LR = 1e-3 + +# CLIP_PARAM = 0.2 + key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = create_model(subkey) @@ -233,7 +254,7 @@ def main(): total_frames = 4000000 train_device = jax.devices()[0] inference_device = jax.devices()[1] - jax.device_put(optimizer.target, device=train_device) + # jax.device_put(optimizer.target, device=train_device) train(optimizer, total_frames, num_agents, train_device, inference_device) if __name__ == '__main__': diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 064acf5913..3cbf1a6261 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -33,6 +33,6 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): next_state = None state = next_state if done: - print(f"Finished Episode {e} with reward {total_reward}") + print(f"------> TEST FINISHED: finished Episode {e} with reward {total_reward}") break del test_env From b0543a9526cc6de85d7a3bb0d043712d539cbf7f Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 16 Sep 2020 08:17:06 +0000 Subject: [PATCH 12/48] Add forward pass tests --- examples/ppo/unit_tests.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py index f2408aa701..20a51131f4 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/unit_tests.py @@ -52,16 +52,26 @@ def test_step(self): self.assertTrue(isinstance(done, bool)) self.assertTrue(isinstance(info, dict)) -#test creation of the model and optimizer +#test the model (creation and forward pass) from models import create_model, create_optimizer -class TestCreation(absltest.TestCase): - def test_create(self): +class TestModel(absltest.TestCase): + def test_model(self): key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) - policy_model = create_model(subkey) - policy_optimizer = create_optimizer(policy_model, learning_rate=1e-3) - self.assertTrue(isinstance(policy_model, nn.base.Model)) - self.assertTrue(isinstance(policy_optimizer, flax.optim.base.Optimizer)) + model = create_model(subkey) + optimizer = create_optimizer(model, learning_rate=1e-3) + self.assertTrue(isinstance(model, nn.base.Model)) + self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) + test_batch_size, obs_shape = 10, (84, 84, 4) + random_input = onp.random.random(size=(test_batch_size,) + obs_shape) + probs, values = optimizer.target(random_input) + self.assertTrue(values.shape == (test_batch_size, 1)) + sum_probs = onp.sum(probs, axis=1) + self.assertTrue(sum_probs.shape == (test_batch_size, )) + onp_testing.assert_almost_equal(sum_probs, onp.ones((test_batch_size, ))) + + + if __name__ == '__main__': absltest.main() \ No newline at end of 
file From 6eedf84e07ba5ad01587b79b3e735ebc10308e15 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 16 Sep 2020 09:34:48 +0000 Subject: [PATCH 13/48] Explicitly mention values shape being (batch,1), not (batch, ) (no influence on results) --- examples/ppo/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 671776f7ed..a6f74ef423 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -43,6 +43,7 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): assert(shapes[0] == (BATCH_SIZE, 84, 84, 4)) assert(all(s == (BATCH_SIZE,) for s in shapes[1:])) probs, values = model(states) + values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) log_probs = jnp.log(probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() # we need to choose probs corresponding to actually taken actions @@ -111,7 +112,7 @@ def thread_inference( probabilities = probs[i] # / probs[i].sum() action = onp.random.choice(probs.shape[1], p=probabilities) #in principle, one could avoid sending value and log prob back and forth - sim.conn.send((action, values[i], onp.log(probs[i][action]))) + sim.conn.send((action, values[i, 0], onp.log(probs[i][action]))) # get experience from simulators experiences = [] From 04763aaca0a1f7a0403b2f96251e260a7da3b6c3 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 16 Sep 2020 11:02:48 +0000 Subject: [PATCH 14/48] Add more asserts, test more frequently --- examples/ppo/main.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index a6f74ef423..2fa55d3ae9 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -57,8 +57,10 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): PG_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, 1. + clip_param) - PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss)) - value_loss = jnp.mean(jnp.square(returns - values)) + assert(PG_loss.shape == clipped_loss.shape) + PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0) + assert(values.shape == returns.shape) + value_loss = jnp.mean(jnp.square(returns - values), axis=0) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy batch_size = BATCH_SIZE @@ -145,7 +147,7 @@ def train( print(f" Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + f"time elapsed {time.time()-t1}") t1 = time.time() - if (s + 1) % (50000 // (num_agents*STEPS_PER_ACTOR)) == 0: + if (s + 1) % (20000 // (num_agents*STEPS_PER_ACTOR)) == 0: test(1, optimizer.target, render=False) @@ -184,9 +186,9 @@ def train( # calculate advantages w. 
GAE advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] + assert(returns.shape == advantages.shape == (STEPS_PER_ACTOR, NUM_AGENTS)) # after preprocessing, concatenate data from all agents trn_data = (states, actions, log_probs, returns, advantages) - trn_data = tuple(map( lambda x: onp.reshape(x, (NUM_AGENTS * STEPS_PER_ACTOR , ) + x.shape[2:]), trn_data) @@ -244,17 +246,18 @@ def train( # CLIP_PARAM = 0.2 -key = jax.random.PRNGKey(0) -key, subkey = jax.random.split(key) -model = create_model(subkey) -optimizer = create_optimizer(model, learning_rate=LR) -del model + def main(): num_agents = NUM_AGENTS total_frames = 4000000 train_device = jax.devices()[0] inference_device = jax.devices()[1] + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + model = create_model(subkey) + optimizer = create_optimizer(model, learning_rate=LR) + del model # jax.device_put(optimizer.target, device=train_device) train(optimizer, total_frames, num_agents, train_device, inference_device) From be01451d21b0e809ff59a320719b5b4aaebf4d7b Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 16 Sep 2020 15:04:59 +0000 Subject: [PATCH 15/48] Use log_probs from the start --- examples/ppo/main.py | 16 ++++++---------- examples/ppo/models.py | 4 ++-- examples/ppo/test_episodes.py | 8 +++++--- examples/ppo/unit_tests.py | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 2fa55d3ae9..e009f767da 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -42,15 +42,11 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): shapes = list(map(lambda x : x.shape, minibatch)) assert(shapes[0] == (BATCH_SIZE, 84, 84, 4)) assert(all(s == (BATCH_SIZE,) for s in shapes[1:])) - probs, values = model(states) + log_probs, values = model(states) values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) - log_probs = jnp.log(probs) + probs = jnp.exp(log_probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() - # we need to choose probs corresponding to actually taken actions - # log_probs_act_taken = log_probs[jnp.arange(probs.shape[0]), actions]) - # above hits "Indexing mode not yet supported.", hence one hot solution - act_one_hot = jax.nn.one_hot(actions, num_classes=probs.shape[1]) - log_probs_act_taken = jnp.log(jnp.sum(act_one_hot*probs, axis=1)) + log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) ratios = jnp.exp(log_probs_act_taken - old_log_probs) # adv. 
normalization (following the OpenAI baselines) advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) @@ -101,9 +97,9 @@ def thread_inference( # perform inference # policy_optimizer, step = q1.get() - probs, values = policy_action(optimizer.target, states) + log_probs, values = policy_action(optimizer.target, states) - probs = onp.array(probs) + probs = onp.exp(onp.array(log_probs)) # print("probs after onp conversion", probs) for i, sim in enumerate(simulators): @@ -114,7 +110,7 @@ def thread_inference( probabilities = probs[i] # / probs[i].sum() action = onp.random.choice(probs.shape[1], p=probabilities) #in principle, one could avoid sending value and log prob back and forth - sim.conn.send((action, values[i, 0], onp.log(probs[i][action]))) + sim.conn.send((action, values[i, 0], log_probs[i][action])) # get experience from simulators experiences = [] diff --git a/examples/ppo/models.py b/examples/ppo/models.py index aea8985b17..385c701b38 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -34,9 +34,9 @@ def apply(self, x): # network used to both estimate policy (logits) and expected state value # see github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py logits = nn.Dense(x, features=4, name='logits', dtype=dtype) - policy_probabilities = nn.softmax(logits) + policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(x, features=1, name='value', dtype=dtype) - return policy_probabilities, value + return policy_log_probabilities, value def create_model(key): input_dims = (1, 84, 84, 4) #(minibatch, height, width, stacked frames) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 3cbf1a6261..56f29f0ea8 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -18,10 +18,12 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): state = get_state(obs) total_reward = 0.0 for t in itertools.count(): - probs, _ = policy_action(model, state) - probs = onp.array(probs, dtype=onp.float64) + log_probs, _ = policy_action(model, state) + probs = onp.exp(onp.array(log_probs, dtype=onp.float32)) probabilities = probs[0] / probs[0].sum() + print(f"probabilities {probabilities}") action = onp.random.choice(probs.shape[1], p=probabilities) + print(f"action {action}") obs, reward, done, _ = test_env.step(action) total_reward += reward if render: @@ -33,6 +35,6 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): next_state = None state = next_state if done: - print(f"------> TEST FINISHED: finished Episode {e} with reward {total_reward}") + print(f"------> TEST FINISHED: finished Episode {e} with reward {total_reward} in {t} steps") break del test_env diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py index 20a51131f4..c3699aff79 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/unit_tests.py @@ -64,9 +64,9 @@ def test_model(self): self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) test_batch_size, obs_shape = 10, (84, 84, 4) random_input = onp.random.random(size=(test_batch_size,) + obs_shape) - probs, values = optimizer.target(random_input) + log_probs, values = optimizer.target(random_input) self.assertTrue(values.shape == (test_batch_size, 1)) - sum_probs = onp.sum(probs, axis=1) + sum_probs = onp.sum(onp.exp(log_probs), axis=1) self.assertTrue(sum_probs.shape == (test_batch_size, )) onp_testing.assert_almost_equal(sum_probs, onp.ones((test_batch_size, ))) From a99baac631c3586a239cd9853e4b96d6aa14c52d 
Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 17 Sep 2020 07:16:05 +0000 Subject: [PATCH 16/48] Thread sync: wait for experience before starting the training --- examples/ppo/main.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index e009f767da..1b9a4a7743 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -155,7 +155,8 @@ def train( # experience is a list of list of tuples, here we preprocess this data to # get required input for GAE and then for training # initial version, needs improvement in terms of speed & readability - if s > 0: #avoid training when there's no data yet + all_experiences = q2.get() + if s >= 0: #avoid training when there's no data yet obs_shape = (84, 84, 4) states = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS) + obs_shape, dtype=onp.float32) @@ -203,9 +204,6 @@ def train( print(f"Step {s} epoch {e} loss {loss} grad norm {last_iter_grad_norm}") #end of PPO training - #collect new data from the inference thread - all_experiences = q2.get() - return None # PPO paper and openAI baselines 2 From c06e8d777d3ff375d64a71167bc8a68e0999bd71 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 17 Sep 2020 10:34:41 +0000 Subject: [PATCH 17/48] Reduce amount of information printed when testing --- examples/ppo/test_episodes.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 56f29f0ea8..c909a2c8f4 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -13,6 +13,7 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): if render: test_env = gym.wrappers.Monitor( test_env, "./rendered/" + "ddqn_pong_recording", force=True) + all_probabilities = [] for e in range(n_episodes): obs = test_env.reset() state = get_state(obs) @@ -21,9 +22,8 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): log_probs, _ = policy_action(model, state) probs = onp.exp(onp.array(log_probs, dtype=onp.float32)) probabilities = probs[0] / probs[0].sum() - print(f"probabilities {probabilities}") + all_probabilities.append(probabilities) action = onp.random.choice(probs.shape[1], p=probabilities) - print(f"action {action}") obs, reward, done, _ = test_env.step(action) total_reward += reward if render: @@ -35,6 +35,10 @@ def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): next_state = None state = next_state if done: - print(f"------> TEST FINISHED: finished Episode {e} with reward {total_reward} in {t} steps") + all_probabilities = onp.stack(all_probabilities, axis=0) + print(f"all_probabilities shape {all_probabilities.shape}") + vars = onp.var(all_probabilities, axis=0) + print(f"------> TEST FINISHED: reward {total_reward} in {t} steps") + print(f"Variance of probabilities across encuntered states {vars}") break del test_env From 21a3540b6d5e80f6fbaaa80aa300e1e17e74427b Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 17 Sep 2020 11:42:04 +0000 Subject: [PATCH 18/48] Clarity: use namedtuple instead of tuple --- examples/ppo/main.py | 14 +++++++------- examples/ppo/remote.py | 6 ++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 1b9a4a7743..b057406130 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -171,15 +171,15 @@ def train( assert(len(all_experiences[0]) == NUM_AGENTS) for t in range(len(all_experiences) - 1): 
#last only for next_values for agent_id, exp_agent in enumerate(all_experiences[t]): - states[t, agent_id, ...] = exp_agent[0] - actions[t, agent_id] = exp_agent[1] - rewards[t, agent_id] =exp_agent[2] - values[t, agent_id] = exp_agent[3] - log_probs[t, agent_id] = exp_agent[4] + states[t, agent_id, ...] = exp_agent.state + actions[t, agent_id] = exp_agent.action + rewards[t, agent_id] = exp_agent.reward + values[t, agent_id] = exp_agent.value + log_probs[t, agent_id] = exp_agent.log_prob # dones need to be 0 for terminal states - dones[t, agent_id] = float(not exp_agent[5]) + dones[t, agent_id] = float(not exp_agent.done) for a in range(num_agents): - values[-1, a] = all_experiences[-1][a][3] + values[-1, a] = all_experiences[-1][a].value # calculate advantages w. GAE advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index ddad8868a3..1ae5e19b6a 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -1,7 +1,10 @@ import multiprocessing import numpy as onp +from collections import namedtuple from env import create_env +exp_tuple = namedtuple('exp_tuple', + ['state', 'action', 'reward', 'value', 'log_prob', 'done']) class RemoteSimulator: """Class that wraps basic functionality needed for an agent @@ -31,8 +34,7 @@ def rcv_action_send_exp(conn): action, value, log_prob = conn.recv() obs, reward, done, _ = env.step(action) next_state = get_state(obs) if not done else None - # maybe a dictionary instead of a tuple would be better? - experience = (state, action, reward, value, log_prob, done) + experience = exp_tuple(state, action, reward, value, log_prob, done) conn.send(experience) if done: break From c18dd9ddcd9472dff71fa57f764e51fba863d2d4 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 17 Sep 2020 12:35:04 +0000 Subject: [PATCH 19/48] Add README --- examples/ppo/README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 examples/ppo/README.md diff --git a/examples/ppo/README.md b/examples/ppo/README.md new file mode 100644 index 0000000000..3f4af121ed --- /dev/null +++ b/examples/ppo/README.md @@ -0,0 +1,12 @@ +# Proximal Policy Optimization + +Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) to learn playing Atari games. + +## Requirements + +* This example depends on the `gym` and `atari-py` packages in addition to `jax` and `flax`. + +## How to run + +`python main.py` runs the main training loop. +Unit tests can be run using `python unit_tests.py` \ No newline at end of file From d9ad5be8eb56f1f7cc8a072d459eb245895f4fea Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 17 Sep 2020 13:38:19 +0000 Subject: [PATCH 20/48] Enhance docstrings --- examples/ppo/env.py | 4 +++- examples/ppo/main.py | 29 ++++++++++++++++++++++++++++- examples/ppo/remote.py | 3 +++ examples/ppo/test_episodes.py | 8 ++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) diff --git a/examples/ppo/env.py b/examples/ppo/env.py index f3ff04a1db..35dec26b9a 100644 --- a/examples/ppo/env.py +++ b/examples/ppo/env.py @@ -6,7 +6,7 @@ class FrameStack: '''Class that wraps an AtariPreprocessing object and implements - stacking of `num_frames` last frames of the game + stacking of `num_frames` last frames of the game. 
''' def __init__(self, preproc: AtariPreprocessing, num_frames : int): self.preproc = preproc @@ -29,6 +29,8 @@ def _get_array(self): return onp.concatenate(self.frames, axis=-1) def create_env(): + '''Create a FrameStack object that serves as environment for the game. + ''' env = gym.make("PongNoFrameskip-v4") preproc = AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index b057406130..d646155946 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -37,6 +37,24 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): @jax.jit def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): + """Compilable train step. Runs an entire epoch of training (i.e. the loop over + minibatches within an epoch is included here for performance reasons. + Args: + optimizer: optimizer for the actor-critic model + trn_data: Tuple of the following five elements forming the experience: + states: shape (steps_per_agent*num_agents, 84, 84, 4) + actions: shape (steps_per_agent*num_agents, 84, 84, 4) + old_log_probs: shape (steps_per_agent*num_agentss, ) + returns: shape (steps_per_agent*num_agentss, ) + advantages: (steps_per_agent*num_agents, ) + clip_param: the PPO clipping parameter used to clamp ratios in loss function + vf_coeff: weighs value function loss in total loss + entropy_coeff: weighs entropy bonus in the total loss + Returns: + optimizer: new optimizer after the parameters update + loss: loss summed over training steps + grad_norm: gradient norm from last step (summed over parameters) + """ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch shapes = list(map(lambda x : x.shape, minibatch)) @@ -128,7 +146,16 @@ def train( num_agents : int, train_device, inference_device): - + """Main training loop. + Args: + optimizer: optimizer for the actor-critic model + steps total: total number of frames (env steps) to train on + num_agents: number of separate processes with agents running the envs + train_device : device used for training + inference_device : device used for inference + Returns: + None + """ simulators = [RemoteSimulator() for i in range(num_agents)] q1, q2 = Queue(maxsize=1), Queue(maxsize=1) inference_thread = threading.Thread(target=thread_inference, diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index 1ae5e19b6a..a6888be88e 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -42,5 +42,8 @@ def rcv_action_send_exp(conn): def get_state(observation): + """Covert observation from Atari environment into a NumPy array and add + a batch dimension. + """ state = onp.array(observation) return state[None, ...] diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index c909a2c8f4..fe49725463 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -9,6 +9,14 @@ from agent import policy_action def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): + """Perform a test of the policy in Atari environment. 
+ Args: + n_episodes: number of full Atari episodes to test on + model: the actor-critic model being tested + render: whether to render the test environment + Returns: + None + """ test_env = create_env() if render: test_env = gym.wrappers.Monitor( From d0ff2ae7f24f97b223e99ca5d9d57da483a31a47 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 18 Sep 2020 08:29:45 +0000 Subject: [PATCH 21/48] Allow more flexible game choice (don't hardcode game-pecific features) --- examples/ppo/env.py | 28 ++++++++++++++++++++++++---- examples/ppo/main.py | 19 ++++++++++++++----- examples/ppo/models.py | 11 ++++++----- examples/ppo/remote.py | 8 ++++---- examples/ppo/test_episodes.py | 8 ++++++-- examples/ppo/unit_tests.py | 20 +++++++++++++++----- 6 files changed, 69 insertions(+), 25 deletions(-) diff --git a/examples/ppo/env.py b/examples/ppo/env.py index 35dec26b9a..2e0f062c82 100644 --- a/examples/ppo/env.py +++ b/examples/ppo/env.py @@ -4,6 +4,17 @@ from seed_rl_atari_preprocessing import AtariPreprocessing +class ClipRewardEnv(gym.RewardWrapper): + """This class is adatpted from OpenAI baselines + github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py + """ + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return onp.sign(reward) + class FrameStack: '''Class that wraps an AtariPreprocessing object and implements stacking of `num_frames` last frames of the game. @@ -28,10 +39,19 @@ def _get_array(self): assert len(self.frames) == self.num_frames return onp.concatenate(self.frames, axis=-1) -def create_env(): - '''Create a FrameStack object that serves as environment for the game. +def create_env(game : str): + '''Create a FrameStack object that serves as environment for the `game`. ''' - env = gym.make("PongNoFrameskip-v4") + env = gym.make(game) + env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} preproc = AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) - return stack \ No newline at end of file + return stack + +def get_num_actions(game : str): + """Get the number of possible actions of a given Atari game. This determines + the number of outputs in the actor part of the actor-critic model. + """ + env = gym.make(game) + return env.action_space.n + diff --git a/examples/ppo/main.py b/examples/ppo/main.py index d646155946..6623f23b8a 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -14,6 +14,7 @@ from agent import policy_action from remote import RemoteSimulator from test_episodes import test +from env import get_num_actions @partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @jax.jit @@ -142,6 +143,7 @@ def thread_inference( def train( optimizer : flax.optim.base.Optimizer, + game : str, steps_total : int, num_agents : int, train_device, @@ -149,6 +151,7 @@ def train( """Main training loop. 
Args: optimizer: optimizer for the actor-critic model + game: string specifying the Atari game from Gym package steps total: total number of frames (env steps) to train on num_agents: number of separate processes with agents running the envs train_device : device used for training @@ -156,7 +159,7 @@ def train( Returns: None """ - simulators = [RemoteSimulator() for i in range(num_agents)] + simulators = [RemoteSimulator(game) for i in range(num_agents)] q1, q2 = Queue(maxsize=1), Queue(maxsize=1) inference_thread = threading.Thread(target=thread_inference, args=(q1, q2, simulators, STEPS_PER_ACTOR), daemon=True) @@ -170,8 +173,8 @@ def train( print(f" Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + f"time elapsed {time.time()-t1}") t1 = time.time() - if (s + 1) % (20000 // (num_agents*STEPS_PER_ACTOR)) == 0: - test(1, optimizer.target, render=False) + if (s + 1) % (2000 // (num_agents*STEPS_PER_ACTOR)) == 0: + test(1, optimizer.target, game, render=False) # send the up-to-date policy model and current step to inference thread @@ -208,6 +211,7 @@ def train( for a in range(num_agents): values[-1, a] = all_experiences[-1][a].value # calculate advantages w. GAE + print(f"nonzero rewards {rewards[onp.nonzero(rewards)]}") advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) returns = advantages + values[:-1, :] assert(returns.shape == advantages.shape == (STEPS_PER_ACTOR, NUM_AGENTS)) @@ -270,17 +274,22 @@ def train( def main(): + game = "Pong" + game += "NoFrameskip-v4" + num_actions = get_num_actions(game) + print(f"Playing {game} with {num_actions} actions") num_agents = NUM_AGENTS total_frames = 4000000 train_device = jax.devices()[0] inference_device = jax.devices()[1] key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) - model = create_model(subkey) + model = create_model(subkey, num_outputs=num_actions) optimizer = create_optimizer(model, learning_rate=LR) del model # jax.device_put(optimizer.target, device=train_device) - train(optimizer, total_frames, num_agents, train_device, inference_device) + train(optimizer, game, total_frames, num_agents, train_device, + inference_device) if __name__ == '__main__': main() diff --git a/examples/ppo/models.py b/examples/ppo/models.py index 385c701b38..60980b5234 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -9,7 +9,7 @@ class ActorCritic(flax.nn.Module): Note that this is different than the one from "Playing atari with deep reinforcement learning." arxiv.org/abs/1312.5602 (2013) ''' - def apply(self, x): + def apply(self, x, num_outputs): x = x.astype(jnp.float32) / 255. 
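# Note (descriptive comment, not part of the committed patch): the inputs here
# are assumed to be stacked uint8 Atari frames of shape (batch, 84, 84, 4) with
# values in [0, 255]; the division by 255. above rescales them to [0, 1] so the
# convolutional stack that follows sees a normalized floating-point input.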
dtype = jnp.float32 x = nn.Conv(x, features=32, kernel_size=(8, 8), @@ -33,15 +33,16 @@ def apply(self, x): x = jnp.maximum(0, x) # network used to both estimate policy (logits) and expected state value # see github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py - logits = nn.Dense(x, features=4, name='logits', dtype=dtype) + logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype) policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(x, features=1, name='value', dtype=dtype) return policy_log_probabilities, value -def create_model(key): +def create_model(key, num_outputs): input_dims = (1, 84, 84, 4) #(minibatch, height, width, stacked frames) - _, initial_par = ActorCritic.init_by_shape(key, [(input_dims, jnp.float32)]) - model = flax.nn.Model(ActorCritic, initial_par) + module = ActorCritic.partial(num_outputs=num_outputs) + _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)]) + model = flax.nn.Model(module, initial_par) return model def create_optimizer(model, learning_rate): diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index a6888be88e..423bfeb86c 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -11,20 +11,20 @@ class RemoteSimulator: emulating Atari in a separate process. An object of this class is created for every agent. """ - def __init__(self): + def __init__(self, game): parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( - target=rcv_action_send_exp, args=(child_conn,)) + target=rcv_action_send_exp, args=(child_conn, game)) self.conn = parent_conn self.proc.start() -def rcv_action_send_exp(conn): +def rcv_action_send_exp(conn, game): """Function running on remote agents. Receives action from the main learner, performs one step of simulation and sends back collected experience. """ - env = create_env() + env = create_env(game) while True: obs = env.reset() done = False diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index fe49725463..949b7a3962 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -8,16 +8,20 @@ from remote import get_state from agent import policy_action -def test(n_episodes: int, model: flax.nn.base.Model, render: bool = False): +def test(n_episodes : int, + model : flax.nn.base.Model, + game : str, + render : bool = False): """Perform a test of the policy in Atari environment. 
Args: n_episodes: number of full Atari episodes to test on model: the actor-critic model being tested + game: defines the Atari game to test on render: whether to render the test environment Returns: None """ - test_env = create_env() + test_env = create_env(game) if render: test_env = gym.wrappers.Monitor( test_env, "./rendered/" + "ddqn_pong_recording", force=True) diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py index c3699aff79..dabd61c678 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/unit_tests.py @@ -34,15 +34,23 @@ def test_gae_random(self): from remote import RemoteSimulator, rcv_action_send_exp from env import create_env class TestEnvironmentPreprocessing(absltest.TestCase): + def choose_random_game(self): + games = ['BeamRider', 'Breakout', 'Pong', + 'Qbert', 'Seaquest', 'SpaceInvaders'] + ind = onp.random.choice(len(games)) + return games[ind] + "NoFrameskip-v4" + def test_creation(self): frame_shape = (84, 84, 4) - env = create_env() + game = self.choose_random_game() + env = create_env(game) obs = env.reset() self.assertTrue(obs.shape == frame_shape) def test_step(self): frame_shape = (84, 84, 4) - env = create_env() + game = self.choose_random_game() + env = create_env(game) obs = env.reset() actions = [1, 2, 3, 0] for a in actions: @@ -55,10 +63,14 @@ def test_step(self): #test the model (creation and forward pass) from models import create_model, create_optimizer class TestModel(absltest.TestCase): + def choose_random_outputs(self): + return onp.random.choice([4,5,6,7,8,9]) + def test_model(self): key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) - model = create_model(subkey) + outputs = self.choose_random_outputs() + model = create_model(subkey, outputs) optimizer = create_optimizer(model, learning_rate=1e-3) self.assertTrue(isinstance(model, nn.base.Model)) self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) @@ -71,7 +83,5 @@ def test_model(self): onp_testing.assert_almost_equal(sum_probs, onp.ones((test_batch_size, ))) - - if __name__ == '__main__': absltest.main() \ No newline at end of file From 1af5bbbd82d9e7a30c06ee52de0fe4f4ac08574c Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 18 Sep 2020 08:33:33 +0000 Subject: [PATCH 22/48] Correctly specify the number of frames --- examples/ppo/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 6623f23b8a..a4feeda0aa 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -166,14 +166,14 @@ def train( inference_thread.start() t1 = time.time() - for s in range(steps_total // num_agents): + for s in range(steps_total // (num_agents * STEPS_PER_ACTOR)): print(f"\n training loop step {s}") #bookkeeping and testing - if (s + 1) % (10000 // (num_agents*STEPS_PER_ACTOR)) == 0: - print(f" Frames processed {s*num_agents*STEPS_PER_ACTOR}, " + - f"time elapsed {time.time()-t1}") + if (s + 1) % (10000 // (num_agents * STEPS_PER_ACTOR)) == 0: + print(f" Frames processed {s * num_agents * STEPS_PER_ACTOR}, " + + f"time elapsed {time.time() - t1}") t1 = time.time() - if (s + 1) % (2000 // (num_agents*STEPS_PER_ACTOR)) == 0: + if (s + 1) % (2000 // (num_agents * STEPS_PER_ACTOR)) == 0: test(1, optimizer.target, game, render=False) @@ -279,7 +279,7 @@ def main(): num_actions = get_num_actions(game) print(f"Playing {game} with {num_actions} actions") num_agents = NUM_AGENTS - total_frames = 4000000 + total_frames = 10000000 train_device = jax.devices()[0] inference_device = 
jax.devices()[1] key = jax.random.PRNGKey(0) From f88e45b00df0199db998eb784c42149854c14c1a Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 18 Sep 2020 14:23:39 +0000 Subject: [PATCH 23/48] Add device_get() for speed as suggested by @jheek --- examples/ppo/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index a4feeda0aa..af80441047 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -117,6 +117,7 @@ def thread_inference( # perform inference # policy_optimizer, step = q1.get() log_probs, values = policy_action(optimizer.target, states) + log_probs, values = jax.device_get((log_probs, values)) probs = onp.exp(onp.array(log_probs)) # print("probs after onp conversion", probs) From 690a9c8986f207068f7f34b5c445f0d75ac1f5e0 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 18 Sep 2020 15:26:41 +0000 Subject: [PATCH 24/48] Add requirements.txt --- examples/ppo/requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 examples/ppo/requirements.txt diff --git a/examples/ppo/requirements.txt b/examples/ppo/requirements.txt new file mode 100644 index 0000000000..3c55647760 --- /dev/null +++ b/examples/ppo/requirements.txt @@ -0,0 +1,4 @@ +jax +jaxlib +gym +atari-py \ No newline at end of file From 58c4ca08cc29454e234ca13ffe4acc810f739752 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 18 Sep 2020 17:42:04 +0000 Subject: [PATCH 25/48] Use absl.flags for better hyperparameter handling --- examples/ppo/main.py | 155 +++++++++++++++++++++++-------------------- 1 file changed, 84 insertions(+), 71 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index af80441047..2d48ed29e1 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -7,6 +7,8 @@ from functools import partial from typing import Tuple, List from queue import Queue +from absl import flags +from absl import app import threading @@ -16,6 +18,61 @@ from test_episodes import test from env import get_num_actions +FLAGS = flags.FLAGS + +# default hyperparameters taken from PPO paper and openAI baselines 2 +# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py + +flags.DEFINE_float( + 'learning_rate', default=2.5e-4, + help=('The learning rate for the Adam optimizer.') +) + +flags.DEFINE_integer( + 'batch_size', default=256, + help=('Batch size for training.') +) + +flags.DEFINE_integer( + 'num_agents', default=8, + help=('Number of agents playing in parallel.') +) + +flags.DEFINE_integer( + 'actor_steps', default=128, + help=('Batch size for training.') +) + +flags.DEFINE_integer( + 'num_epochs', default=3, + help=('Number of epochs per each unroll of the policy.') +) + +flags.DEFINE_float( + 'gamma', default=0.99, + help=('Discount parameter.') +) + +flags.DEFINE_float( + 'lambda_', default=0.95, + help=('Generalized Advantage Estimation parameter.') +) + +flags.DEFINE_float( + 'clip_param', default=0.1, + help=('The PPO clipping parameter used to clamp ratios in loss function.') +) + +flags.DEFINE_float( + 'vf_coeff', default=0.5, + help=('Weighs value function loss in the total loss.') +) + +flags.DEFINE_float( + 'entropy_coeff', default=0.01, + help=('Weighs entropy bonus in the total loss.') +) + @partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @jax.jit def gae_advantages(rewards, terminal_masks, values, discount, gae_param): @@ -59,8 +116,6 @@ def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): def loss_fn(model, minibatch, clip_param, 
vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch shapes = list(map(lambda x : x.shape, minibatch)) - assert(shapes[0] == (BATCH_SIZE, 84, 84, 4)) - assert(all(s == (BATCH_SIZE,) for s in shapes[1:])) log_probs, values = model(states) values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) probs = jnp.exp(log_probs) @@ -78,7 +133,7 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): value_loss = jnp.mean(jnp.square(returns - values), axis=0) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy - batch_size = BATCH_SIZE + batch_size = FLAGS.batch_size iterations = trn_data[0].shape[0] // batch_size trn_data = jax.tree_map( lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trn_data) @@ -163,18 +218,18 @@ def train( simulators = [RemoteSimulator(game) for i in range(num_agents)] q1, q2 = Queue(maxsize=1), Queue(maxsize=1) inference_thread = threading.Thread(target=thread_inference, - args=(q1, q2, simulators, STEPS_PER_ACTOR), daemon=True) + args=(q1, q2, simulators, FLAGS.actor_steps), daemon=True) inference_thread.start() t1 = time.time() - for s in range(steps_total // (num_agents * STEPS_PER_ACTOR)): + for s in range(steps_total // (num_agents * FLAGS.actor_steps)): print(f"\n training loop step {s}") #bookkeeping and testing - if (s + 1) % (10000 // (num_agents * STEPS_PER_ACTOR)) == 0: - print(f" Frames processed {s * num_agents * STEPS_PER_ACTOR}, " + + if (s + 1) % (10000 // (num_agents * FLAGS.actor_steps)) == 0: + print(f" Frames processed {s * num_agents * FLAGS.actor_steps}, " + f"time elapsed {time.time() - t1}") t1 = time.time() - if (s + 1) % (2000 // (num_agents * STEPS_PER_ACTOR)) == 0: + if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: test(1, optimizer.target, game, render=False) @@ -189,17 +244,16 @@ def train( all_experiences = q2.get() if s >= 0: #avoid training when there's no data yet obs_shape = (84, 84, 4) - states = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS) + obs_shape, - dtype=onp.float32) - actions = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.int32) - rewards = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) - values = onp.zeros((STEPS_PER_ACTOR + 1, NUM_AGENTS), dtype=onp.float32) - log_probs = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), - dtype=onp.float32) - dones = onp.zeros((STEPS_PER_ACTOR, NUM_AGENTS), dtype=onp.float32) - - assert(len(all_experiences) == STEPS_PER_ACTOR + 1) - assert(len(all_experiences[0]) == NUM_AGENTS) + exp_dims = (FLAGS.actor_steps, FLAGS.num_agents) + values_dims = (FLAGS.actor_steps + 1, FLAGS.num_agents) + states = onp.zeros(exp_dims + obs_shape, dtype=onp.float32) + actions = onp.zeros(exp_dims, dtype=onp.int32) + rewards = onp.zeros(exp_dims, dtype=onp.float32) + values = onp.zeros(values_dims, dtype=onp.float32) + log_probs = onp.zeros(exp_dims, dtype=onp.float32) + dones = onp.zeros(exp_dims, dtype=onp.float32) + + assert(len(all_experiences[0]) == FLAGS.num_agents) for t in range(len(all_experiences) - 1): #last only for next_values for agent_id, exp_agent in enumerate(all_experiences[t]): states[t, agent_id, ...] = exp_agent.state @@ -212,85 +266,44 @@ def train( for a in range(num_agents): values[-1, a] = all_experiences[-1][a].value # calculate advantages w. 
GAE - print(f"nonzero rewards {rewards[onp.nonzero(rewards)]}") - advantages = gae_advantages(rewards, dones, values, DISCOUNT, GAE_PARAM) + advantages = gae_advantages(rewards, dones, values, + FLAGS.gamma, FLAGS.lambda_) returns = advantages + values[:-1, :] - assert(returns.shape == advantages.shape == (STEPS_PER_ACTOR, NUM_AGENTS)) # after preprocessing, concatenate data from all agents trn_data = (states, actions, log_probs, returns, advantages) trn_data = tuple(map( lambda x: onp.reshape(x, - (NUM_AGENTS * STEPS_PER_ACTOR , ) + x.shape[2:]), trn_data) + (FLAGS.num_agents * FLAGS.actor_steps, ) + x.shape[2:]), trn_data) ) print(f"Step {s}: rewards variance {rewards.var()}") - dr = dones.ravel() - print(f"fraction of terminal states {1.-(dr.sum()/dr.shape[0])}") - for e in range(NUM_EPOCHS): #possibly compile this loop inside a jit + for e in range(FLAGS.num_epochs): #possibly compile this loop inside a jit shapes = list(map(lambda x : x.shape, trn_data)) - assert(shapes[0] == (NUM_AGENTS * STEPS_PER_ACTOR, 84, 84, 4)) - assert(all(s == (NUM_AGENTS * STEPS_PER_ACTOR,) for s in shapes[1:])) - permutation = onp.random.permutation(NUM_AGENTS * STEPS_PER_ACTOR) + permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) trn_data = tuple(map(lambda x: x[permutation], trn_data)) optimizer, loss, last_iter_grad_norm = train_step(optimizer, trn_data, - CLIP_PARAM, VF_COEFF, ENTROPY_COEFF) - print(f"Step {s} epoch {e} loss {loss} grad norm {last_iter_grad_norm}") + FLAGS.clip_param, FLAGS.vf_coeff, FLAGS.entropy_coeff) + print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") #end of PPO training return None -# PPO paper and openAI baselines 2 -# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py -STEPS_PER_ACTOR = 128 -NUM_AGENTS = 8 -NUM_EPOCHS = 3 -BATCH_SIZE = 32 * 8 - -DISCOUNT = 0.99 #usually denoted with \gamma -GAE_PARAM = 0.95 #usually denoted with \lambda - -VF_COEFF = 0.5 #weighs value function loss in total loss -ENTROPY_COEFF = 0.01 # weighs entropy bonus in the total loss - -LR = 2.5e-4 - -CLIP_PARAM = 0.1 - -# openAI baselines 1 -# https://github.com/openai/baselines/blob/master/baselines/ppo1/run_atari.py -# STEPS_PER_ACTOR = 256 -# NUM_AGENTS = 8 -# NUM_EPOCHS = 4 -# BATCH_SIZE = 64 - -# DISCOUNT = 0.99 #usually denoted with \gamma -# GAE_PARAM = 0.95 #usually denoted with \lambda - -# VF_COEFF = 1. 
#weighs value function loss in total loss -# ENTROPY_COEFF = 0.01 # weighs entropy bonus in the total loss - -# LR = 1e-3 - -# CLIP_PARAM = 0.2 - - - -def main(): +def main(argv): game = "Pong" game += "NoFrameskip-v4" num_actions = get_num_actions(game) print(f"Playing {game} with {num_actions} actions") - num_agents = NUM_AGENTS + num_agents = FLAGS.num_agents total_frames = 10000000 train_device = jax.devices()[0] inference_device = jax.devices()[1] key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = create_model(subkey, num_outputs=num_actions) - optimizer = create_optimizer(model, learning_rate=LR) + optimizer = create_optimizer(model, learning_rate=FLAGS.learning_rate) del model # jax.device_put(optimizer.target, device=train_device) train(optimizer, game, total_frames, num_agents, train_device, inference_device) if __name__ == '__main__': - main() + app.run(main) \ No newline at end of file From f53c1df003608da37e683891d1e1b5503fe3724a Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 21 Sep 2020 09:10:49 +0000 Subject: [PATCH 26/48] Style improvement (comments by @lespeholt and @8bitmp3 & beyond) --- examples/ppo/agent.py | 11 +- examples/ppo/{env.py => env_utils.py} | 32 +++--- examples/ppo/main.py | 107 +++++++++++--------- examples/ppo/models.py | 17 ++-- examples/ppo/remote.py | 29 +++--- examples/ppo/seed_rl_atari_preprocessing.py | 2 +- examples/ppo/test_episodes.py | 18 ++-- examples/ppo/unit_tests.py | 23 ++--- 8 files changed, 133 insertions(+), 106 deletions(-) rename examples/ppo/{env.py => env_utils.py} (64%) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index df64825a96..6a75b69e1d 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -1,11 +1,12 @@ +"""Agent's moves (forward pass of the network).""" + import jax import numpy as onp @jax.jit def policy_action(model, state): - """Forward pass of the network. - Potentially the random choice of the action from probabilities can be moved - here with additional rng_key parameter. - """ + """Forward pass of the network.""" + # Potentially the random choice of the action from probabilities can be moved + # here with additional rng_key parameter. out = model(state) - return out \ No newline at end of file + return out diff --git a/examples/ppo/env.py b/examples/ppo/env_utils.py similarity index 64% rename from examples/ppo/env.py rename to examples/ppo/env_utils.py index 2e0f062c82..638a56188d 100644 --- a/examples/ppo/env.py +++ b/examples/ppo/env_utils.py @@ -1,13 +1,17 @@ +"""Utilities for handling the Atari environment.""" + import collections import gym import numpy as onp -from seed_rl_atari_preprocessing import AtariPreprocessing +import seed_rl_atari_preprocessing class ClipRewardEnv(gym.RewardWrapper): - """This class is adatpted from OpenAI baselines + """Adapted from OpenAI baselines. + github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py """ + def __init__(self, env): gym.RewardWrapper.__init__(self, env) @@ -16,10 +20,14 @@ def reward(self, reward): return onp.sign(reward) class FrameStack: - '''Class that wraps an AtariPreprocessing object and implements - stacking of `num_frames` last frames of the game. - ''' - def __init__(self, preproc: AtariPreprocessing, num_frames : int): + """Implements stacking of `num_frames` last frames of the game. + + Wraps an AtariPreprocessing object. 
+ """ + + def __init__(self, + preproc: seed_rl_atari_preprocessing.AtariPreprocessing, + num_frames: int): self.preproc = preproc self.num_frames = num_frames self.frames = collections.deque(maxlen=num_frames) @@ -40,18 +48,18 @@ def _get_array(self): return onp.concatenate(self.frames, axis=-1) def create_env(game : str): - '''Create a FrameStack object that serves as environment for the `game`. - ''' + """Create a FrameStack object that serves as environment for the `game`.""" env = gym.make(game) env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} - preproc = AtariPreprocessing(env) + preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) return stack def get_num_actions(game : str): - """Get the number of possible actions of a given Atari game. This determines - the number of outputs in the actor part of the actor-critic model. + """Get the number of possible actions of a given Atari game. + + This determines the number of outputs in the actor part of the + actor-critic model. """ env = gym.make(game) return env.action_space.n - diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 2d48ed29e1..788ba0cdff 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -4,19 +4,18 @@ import numpy as onp import flax import time -from functools import partial -from typing import Tuple, List -from queue import Queue +import functools +import queue from absl import flags from absl import app import threading +from typing import Tuple, List - -from models import create_model, create_optimizer -from agent import policy_action -from remote import RemoteSimulator -from test_episodes import test -from env import get_num_actions +import models +import agent +import remote +import test_episodes +import env_utils FLAGS = flags.FLAGS @@ -73,15 +72,16 @@ help=('Weighs entropy bonus in the total loss.') ) -@partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) +@functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @jax.jit def gae_advantages(rewards, terminal_masks, values, discount, gae_param): - """Use Generalized Advantage Estimation (GAE) to compute advantages - Eqs. (11-12) in PPO paper arXiv: 1707.06347. - Uses key observation that A_{t} = \delta_t + \gamma*\lambda*A_{t+1}. + """Use Generalized Advantage Estimation (GAE) to compute advantages. + + As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementaion uses + key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}. """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " - "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate \delta_t") + "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") advantages, gae = [], 0. for t in reversed(range(len(rewards))): # masks to set next state value to 0 for terminal states @@ -95,19 +95,23 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): @jax.jit def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): - """Compilable train step. Runs an entire epoch of training (i.e. the loop over - minibatches within an epoch is included here for performance reasons. + """Compilable train step. + + Runs an entire epoch of training (i.e. the loop over + minibatches within an epoch is included here for performance reasons). 
+ Args: optimizer: optimizer for the actor-critic model trn_data: Tuple of the following five elements forming the experience: states: shape (steps_per_agent*num_agents, 84, 84, 4) actions: shape (steps_per_agent*num_agents, 84, 84, 4) - old_log_probs: shape (steps_per_agent*num_agentss, ) - returns: shape (steps_per_agent*num_agentss, ) + old_log_probs: shape (steps_per_agent*num_agents, ) + returns: shape (steps_per_agent*num_agents, ) advantages: (steps_per_agent*num_agents, ) clip_param: the PPO clipping parameter used to clamp ratios in loss function vf_coeff: weighs value function loss in total loss entropy_coeff: weighs entropy bonus in the total loss + Returns: optimizer: new optimizer after the parameters update loss: loss summed over training steps @@ -140,8 +144,8 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): loss = 0. for batch in zip(*trn_data): grad_fn = jax.value_and_grad(loss_fn) - l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, - entropy_coeff) + l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, + entropy_coeff) loss += l optimizer = optimizer.apply_gradient(grad) grad_norm = sum(jnp.square(g).sum() for g in jax.tree_leaves(grad)) @@ -149,17 +153,16 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): def thread_inference( - q1 : Queue, - q2: Queue, - simulators : List[RemoteSimulator], - steps_per_actor : int): - """Worker function for a separate thread used for inference and running - the simulators in order to maximize the GPU/TPU usage. Runs - `steps_per_actor` time steps of the game for each of the `simulators`. + policy_q: queue.Queue, + experience_q: queue.Queue, + simulators: List[remote.RemoteSimulator], + steps_per_actor: int): + """Worker function for a separate inference thread. + + Runs `steps_per_actor` time steps of the game for each of the `simulators`. """ - while(True): - optimizer, step = q1.get() + optimizer, step = policy_q.get() all_experience = [] for _ in range(steps_per_actor + 1): # +1 to get one more value # needed for GAE @@ -170,8 +173,8 @@ def thread_inference( states = onp.concatenate(states, axis=0) # perform inference - # policy_optimizer, step = q1.get() - log_probs, values = policy_action(optimizer.target, states) + # policy_optimizer, step = policy_q.get() + log_probs, values = agent.policy_action(optimizer.target, states) log_probs, values = jax.device_get((log_probs, values)) probs = onp.exp(onp.array(log_probs)) @@ -194,17 +197,18 @@ def thread_inference( experiences.append(sample) all_experience.append(experiences) - q2.put(all_experience) + experience_q.put(all_experience) def train( - optimizer : flax.optim.base.Optimizer, - game : str, - steps_total : int, - num_agents : int, + optimizer: flax.optim.base.Optimizer, + game: str, + steps_total: int, + num_agents: int, train_device, inference_device): """Main training loop. 
+ Args: optimizer: optimizer for the actor-critic model game: string specifying the Atari game from Gym package @@ -212,13 +216,17 @@ def train( num_agents: number of separate processes with agents running the envs train_device : device used for training inference_device : device used for inference + Returns: None """ - simulators = [RemoteSimulator(game) for i in range(num_agents)] - q1, q2 = Queue(maxsize=1), Queue(maxsize=1) - inference_thread = threading.Thread(target=thread_inference, - args=(q1, q2, simulators, FLAGS.actor_steps), daemon=True) + simulators = [remote.RemoteSimulator(game) for i in range(num_agents)] + policy_q = queue.Queue(maxsize=1) + experience_q = queue.Queue(maxsize=1) + inference_thread = threading.Thread( + target=thread_inference, + args=(policy_q, experience_q, simulators, FLAGS.actor_steps), + daemon=True) inference_thread.start() t1 = time.time() @@ -230,18 +238,18 @@ def train( f"time elapsed {time.time() - t1}") t1 = time.time() if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: - test(1, optimizer.target, game, render=False) + test_episodes.test(1, optimizer.target, game, render=False) # send the up-to-date policy model and current step to inference thread step = s*num_agents - q1.put((optimizer, step)) + policy_q.put((optimizer, step)) # perform PPO training # experience is a list of list of tuples, here we preprocess this data to # get required input for GAE and then for training # initial version, needs improvement in terms of speed & readability - all_experiences = q2.get() + all_experiences = experience_q.get() if s >= 0: #avoid training when there's no data yet obs_shape = (84, 84, 4) exp_dims = (FLAGS.actor_steps, FLAGS.num_agents) @@ -272,9 +280,8 @@ def train( # after preprocessing, concatenate data from all agents trn_data = (states, actions, log_probs, returns, advantages) trn_data = tuple(map( - lambda x: onp.reshape(x, - (FLAGS.num_agents * FLAGS.actor_steps, ) + x.shape[2:]), trn_data) - ) + lambda x: onp.reshape( + x, (FLAGS.num_agents * FLAGS.actor_steps, ) + x.shape[2:]), trn_data)) print(f"Step {s}: rewards variance {rewards.var()}") for e in range(FLAGS.num_epochs): #possibly compile this loop inside a jit shapes = list(map(lambda x : x.shape, trn_data)) @@ -290,7 +297,7 @@ def train( def main(argv): game = "Pong" game += "NoFrameskip-v4" - num_actions = get_num_actions(game) + num_actions = env_utils.get_num_actions(game) print(f"Playing {game} with {num_actions} actions") num_agents = FLAGS.num_agents total_frames = 10000000 @@ -298,12 +305,12 @@ def main(argv): inference_device = jax.devices()[1] key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) - model = create_model(subkey, num_outputs=num_actions) - optimizer = create_optimizer(model, learning_rate=FLAGS.learning_rate) + model = models.create_model(subkey, num_outputs=num_actions) + optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model # jax.device_put(optimizer.target, device=train_device) train(optimizer, game, total_frames, num_agents, train_device, inference_device) if __name__ == '__main__': - app.run(main) \ No newline at end of file + app.run(main) diff --git a/examples/ppo/models.py b/examples/ppo/models.py index 60980b5234..7c6f92e624 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -1,15 +1,20 @@ +"""Class and functions to define and initialize the actor-critic model.""" + import flax from flax import nn import jax.numpy as jnp class ActorCritic(flax.nn.Module): - ''' - Architecture from 
"Human-level control through deep reinforcement learning." - Nature 518, no. 7540 (2015): 529-533. - Note that this is different than the one from "Playing atari with deep - reinforcement learning." arxiv.org/abs/1312.5602 (2013) - ''' + """Class defining the actor-critic model.""" + def apply(self, x, num_outputs): + """Define the convolutional network architecture. + + Architecture originates from "Human-level control through deep reinforcement + learning.", Nature 518, no. 7540 (2015): 529-533. + Note that this is different than the one from "Playing atari with deep + reinforcement learning." arxiv.org/abs/1312.5602 (2013) + """ x = x.astype(jnp.float32) / 255. dtype = jnp.float32 x = nn.Conv(x, features=32, kernel_size=(8, 8), diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index 423bfeb86c..836a85070c 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -1,16 +1,20 @@ +"""Utilities for running the agents in separate processes.""" + import multiprocessing import numpy as onp -from collections import namedtuple -from env import create_env +import collections + +import env_utils -exp_tuple = namedtuple('exp_tuple', - ['state', 'action', 'reward', 'value', 'log_prob', 'done']) +exp_tuple = collections.namedtuple( + 'exp_tuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) class RemoteSimulator: - """Class that wraps basic functionality needed for an agent - emulating Atari in a separate process. + """Wrap functionality for an agent emulating Atari in a separate process. + An object of this class is created for every agent. """ + def __init__(self, game): parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( @@ -20,11 +24,12 @@ def __init__(self, game): def rcv_action_send_exp(conn, game): - """Function running on remote agents. Receives action from - the main learner, performs one step of simulation and - sends back collected experience. + """Run the remote agents. + + Receive action from the main learner, perform one step of simulation and + send back collected experience. """ - env = create_env(game) + env = env_utils.create_env(game) while True: obs = env.reset() done = False @@ -42,8 +47,6 @@ def rcv_action_send_exp(conn, game): def get_state(observation): - """Covert observation from Atari environment into a NumPy array and add - a batch dimension. - """ + """Convert Atari env observation into a NumPy array, add batch dimension.""" state = onp.array(observation) return state[None, ...] diff --git a/examples/ppo/seed_rl_atari_preprocessing.py b/examples/ppo/seed_rl_atari_preprocessing.py index 2e5982d760..f51faad4fa 100644 --- a/examples/ppo/seed_rl_atari_preprocessing.py +++ b/examples/ppo/seed_rl_atari_preprocessing.py @@ -13,7 +13,7 @@ # limitations under the License. """A class implementing minimal Atari 2600 preprocessing. -Adapted from Dopamine. +Adapted from SEED RL, originally adapted from Dopamine. 
""" from gym.spaces.box import Box diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 949b7a3962..added4a549 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -1,37 +1,41 @@ +"""Test policy by playing a full Atari game.""" + import time import itertools import gym import flax import numpy as onp -from env import create_env -from remote import get_state -from agent import policy_action +import env_utils +import remote +import agent def test(n_episodes : int, model : flax.nn.base.Model, game : str, render : bool = False): """Perform a test of the policy in Atari environment. + Args: n_episodes: number of full Atari episodes to test on model: the actor-critic model being tested game: defines the Atari game to test on render: whether to render the test environment + Returns: None """ - test_env = create_env(game) + test_env = env_utils.create_env(game) if render: test_env = gym.wrappers.Monitor( test_env, "./rendered/" + "ddqn_pong_recording", force=True) all_probabilities = [] for e in range(n_episodes): obs = test_env.reset() - state = get_state(obs) + state = remote.get_state(obs) total_reward = 0.0 for t in itertools.count(): - log_probs, _ = policy_action(model, state) + log_probs, _ = agent.policy_action(model, state) probs = onp.exp(onp.array(log_probs, dtype=onp.float32)) probabilities = probs[0] / probs[0].sum() all_probabilities.append(probabilities) @@ -42,7 +46,7 @@ def test(n_episodes : int, test_env.render() time.sleep(0.01) if not done: - next_state = get_state(obs) + next_state = remote.get_state(obs) else: next_state = None state = next_state diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py index dabd61c678..4c90d6a674 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/unit_tests.py @@ -1,14 +1,15 @@ - import jax import flax from flax import nn import numpy as onp - import numpy.testing as onp_testing from absl.testing import absltest +import main +import env_utils +import models + #test GAE -from main import gae_advantages class TestGAE(absltest.TestCase): def test_gae_random(self): # create random data, simulating 4 parallel envs and 20 time_steps @@ -19,7 +20,8 @@ def test_gae_random(self): values = onp.random.random(size=(steps + 1, envs)) discount = 0.99 gae_param = 0.95 - adv = gae_advantages(rewards, terminal_masks, values, discount, gae_param) + adv = main.gae_advantages(rewards, terminal_masks, values, discount, + gae_param) self.assertEqual(adv.shape, (steps, envs)) # test the property A_{t} = \delta_t + \gamma*\lambda*A_{t+1} # for each agent separately @@ -31,8 +33,6 @@ def test_gae_random(self): onp_testing.assert_almost_equal(lhs, rhs) #test environment and preprocessing -from remote import RemoteSimulator, rcv_action_send_exp -from env import create_env class TestEnvironmentPreprocessing(absltest.TestCase): def choose_random_game(self): games = ['BeamRider', 'Breakout', 'Pong', @@ -43,14 +43,14 @@ def choose_random_game(self): def test_creation(self): frame_shape = (84, 84, 4) game = self.choose_random_game() - env = create_env(game) + env = env_utils.create_env(game) obs = env.reset() self.assertTrue(obs.shape == frame_shape) def test_step(self): frame_shape = (84, 84, 4) game = self.choose_random_game() - env = create_env(game) + env = env_utils.create_env(game) obs = env.reset() actions = [1, 2, 3, 0] for a in actions: @@ -61,7 +61,6 @@ def test_step(self): self.assertTrue(isinstance(info, dict)) #test the model (creation and forward pass) -from models import 
create_model, create_optimizer class TestModel(absltest.TestCase): def choose_random_outputs(self): return onp.random.choice([4,5,6,7,8,9]) @@ -70,8 +69,8 @@ def test_model(self): key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) outputs = self.choose_random_outputs() - model = create_model(subkey, outputs) - optimizer = create_optimizer(model, learning_rate=1e-3) + model = models.create_model(subkey, outputs) + optimizer = models.create_optimizer(model, learning_rate=1e-3) self.assertTrue(isinstance(model, nn.base.Model)) self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) test_batch_size, obs_shape = 10, (84, 84, 4) @@ -84,4 +83,4 @@ def test_model(self): if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() From 2b10c33261f6fc34fd6b8081a236ac366c8ec137 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 21 Sep 2020 13:51:57 +0000 Subject: [PATCH 27/48] Don't bin rewards during testing --- examples/ppo/env_utils.py | 13 +++++++------ examples/ppo/remote.py | 6 +++--- examples/ppo/test_episodes.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/ppo/env_utils.py b/examples/ppo/env_utils.py index 638a56188d..0150f0a5fe 100644 --- a/examples/ppo/env_utils.py +++ b/examples/ppo/env_utils.py @@ -24,7 +24,7 @@ class FrameStack: Wraps an AtariPreprocessing object. """ - + def __init__(self, preproc: seed_rl_atari_preprocessing.AtariPreprocessing, num_frames: int): @@ -47,18 +47,19 @@ def _get_array(self): assert len(self.frames) == self.num_frames return onp.concatenate(self.frames, axis=-1) -def create_env(game : str): +def create_env(game: str, clip_rewards: bool): """Create a FrameStack object that serves as environment for the `game`.""" env = gym.make(game) - env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} + if clip_rewards: + env = ClipRewardEnv(env) # bin rewards to {-1., 0., 1.} preproc = seed_rl_atari_preprocessing.AtariPreprocessing(env) stack = FrameStack(preproc, num_frames=4) return stack -def get_num_actions(game : str): +def get_num_actions(game: str): """Get the number of possible actions of a given Atari game. - - This determines the number of outputs in the actor part of the + + This determines the number of outputs in the actor part of the actor-critic model. """ env = gym.make(game) diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index 836a85070c..b6494e87b2 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -14,7 +14,7 @@ class RemoteSimulator: An object of this class is created for every agent. """ - + def __init__(self, game): parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( @@ -25,11 +25,11 @@ def __init__(self, game): def rcv_action_send_exp(conn, game): """Run the remote agents. - + Receive action from the main learner, perform one step of simulation and send back collected experience. 
""" - env = env_utils.create_env(game) + env = env_utils.create_env(game, clip_rewards=True) while True: obs = env.reset() done = False diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index added4a549..4391c6b7f2 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -25,7 +25,7 @@ def test(n_episodes : int, Returns: None """ - test_env = env_utils.create_env(game) + test_env = env_utils.create_env(game, clip_rewards=False) if render: test_env = gym.wrappers.Monitor( test_env, "./rendered/" + "ddqn_pong_recording", force=True) From da0ec77723ad7f3ef19642b3234e599bc0032bbe Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 21 Sep 2020 15:13:31 +0000 Subject: [PATCH 28/48] Update testing requirements --- examples/ppo/README.md | 2 +- examples/ppo/main.py | 2 +- examples/ppo/requirements.txt | 5 +++-- examples/ppo/test_episodes.py | 3 ++- setup.py | 2 ++ 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/ppo/README.md b/examples/ppo/README.md index 3f4af121ed..5def2ceecf 100644 --- a/examples/ppo/README.md +++ b/examples/ppo/README.md @@ -4,7 +4,7 @@ Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https:/ ## Requirements -* This example depends on the `gym` and `atari-py` packages in addition to `jax` and `flax`. +* This example depends on the `gym`, `opencv-python` and `atari-py` packages in addition to `jax` and `flax`. ## How to run diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 788ba0cdff..7f345fbaa0 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -238,7 +238,7 @@ def train( f"time elapsed {time.time() - t1}") t1 = time.time() if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: - test_episodes.test(1, optimizer.target, game, render=False) + test_episodes.policy_test(1, optimizer.target, game, render=False) # send the up-to-date policy model and current step to inference thread diff --git a/examples/ppo/requirements.txt b/examples/ppo/requirements.txt index 3c55647760..69d6538acf 100644 --- a/examples/ppo/requirements.txt +++ b/examples/ppo/requirements.txt @@ -1,4 +1,5 @@ +atari-py +gym jax jaxlib -gym -atari-py \ No newline at end of file +opencv-python \ No newline at end of file diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 4391c6b7f2..4afa11851c 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -10,7 +10,8 @@ import remote import agent -def test(n_episodes : int, +def policy_test( + n_episodes : int, model : flax.nn.base.Model, game : str, render : bool = False): diff --git a/setup.py b/setup.py index 3fdec275fc..babbdcf69d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,9 @@ ] tests_require = [ + "gym", "jaxlib", + "opencv-python", "pytest", "pytest-cov", "pytest-xdist==1.34.0", # upgrading to 2.0 broke tests, need to investigate From 9c72f00e64d14f632138b838c6c00dba6c14c3db Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 21 Sep 2020 17:57:15 +0000 Subject: [PATCH 29/48] Implement the decay of the clip parameter and learning rate --- examples/ppo/main.py | 75 ++++++++++++++++++++++++++------------------ 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 7f345fbaa0..f654bd39ae 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -23,53 +23,59 @@ # https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py flags.DEFINE_float( - 'learning_rate', default=2.5e-4, - 
help=('The learning rate for the Adam optimizer.') + 'learning_rate', default=2.5e-4, + help=('The learning rate for the Adam optimizer.') ) flags.DEFINE_integer( - 'batch_size', default=256, - help=('Batch size for training.') + 'batch_size', default=256, + help=('Batch size for training.') ) flags.DEFINE_integer( - 'num_agents', default=8, - help=('Number of agents playing in parallel.') + 'num_agents', default=8, + help=('Number of agents playing in parallel.') ) flags.DEFINE_integer( - 'actor_steps', default=128, - help=('Batch size for training.') + 'actor_steps', default=128, + help=('Batch size for training.') ) flags.DEFINE_integer( - 'num_epochs', default=3, - help=('Number of epochs per each unroll of the policy.') + 'num_epochs', default=3, + help=('Number of epochs per each unroll of the policy.') ) flags.DEFINE_float( - 'gamma', default=0.99, - help=('Discount parameter.') + 'gamma', default=0.99, + help=('Discount parameter.') ) flags.DEFINE_float( - 'lambda_', default=0.95, - help=('Generalized Advantage Estimation parameter.') + 'lambda_', default=0.95, + help=('Generalized Advantage Estimation parameter.') ) flags.DEFINE_float( - 'clip_param', default=0.1, - help=('The PPO clipping parameter used to clamp ratios in loss function.') + 'clip_param', default=0.1, + help=('The PPO clipping parameter used to clamp ratios in loss function.') ) flags.DEFINE_float( - 'vf_coeff', default=0.5, - help=('Weighs value function loss in the total loss.') + 'vf_coeff', default=0.5, + help=('Weighs value function loss in the total loss.') ) flags.DEFINE_float( - 'entropy_coeff', default=0.01, - help=('Weighs entropy bonus in the total loss.') + 'entropy_coeff', default=0.01, + help=('Weighs entropy bonus in the total loss.') +) + +flags.DEFINE_boolean( + 'decaying_lr_and_clip_param', default=True, + help=(('Linearly decay learning rate and clipping parameter to zero during ' + 'the training.')) ) @functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @@ -77,7 +83,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): """Use Generalized Advantage Estimation (GAE) to compute advantages. - As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementaion uses + As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}. """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " @@ -94,9 +100,9 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): return jnp.array(advantages) @jax.jit -def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): +def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr): """Compilable train step. - + Runs an entire epoch of training (i.e. the loop over minibatches within an epoch is included here for performance reasons). @@ -111,6 +117,8 @@ def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff): clip_param: the PPO clipping parameter used to clamp ratios in loss function vf_coeff: weighs value function loss in total loss entropy_coeff: weighs entropy bonus in the total loss + lr: learning rate, varies between optimization steps + if decaying_lr_and_clip_param is set to true Returns: optimizer: new optimizer after the parameters update @@ -144,10 +152,10 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): loss = 0. 
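# A standalone sketch (not part of the patch) of how the clipped surrogate
# objective defined in loss_fn above behaves; `clipped_objective` and the toy
# numbers below are illustrative only. With clip_param = 0.1, any ratio
# outside [0.9, 1.1] is clamped, so e.g. the last sample contributes
# 2.0 * 1.1 = 2.2 rather than 2.0 * 1.3 = 2.6 to the (negated) objective.
import jax
import jax.numpy as jnp

def clipped_objective(ratios, advantages, clip_param=0.1):
  pg_term = ratios * advantages
  clipped_term = advantages * jax.lax.clamp(1. - clip_param, ratios,
                                            1. + clip_param)
  return -jnp.mean(jnp.minimum(pg_term, clipped_term))

ratios = jnp.array([0.8, 1.0, 1.3])
advantages = jnp.array([1.0, -1.0, 2.0])
print(clipped_objective(ratios, advantages))  # -(0.8 - 1.0 + 2.2) / 3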
for batch in zip(*trn_data): grad_fn = jax.value_and_grad(loss_fn) - l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, + l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, entropy_coeff) loss += l - optimizer = optimizer.apply_gradient(grad) + optimizer = optimizer.apply_gradient(grad, learning_rate=lr) grad_norm = sum(jnp.square(g).sum() for g in jax.tree_leaves(grad)) return optimizer, loss, grad_norm @@ -158,7 +166,7 @@ def thread_inference( simulators: List[remote.RemoteSimulator], steps_per_actor: int): """Worker function for a separate inference thread. - + Runs `steps_per_actor` time steps of the game for each of the `simulators`. """ while(True): @@ -229,8 +237,8 @@ def train( daemon=True) inference_thread.start() t1 = time.time() - - for s in range(steps_total // (num_agents * FLAGS.actor_steps)): + loop_steps = steps_total // (num_agents * FLAGS.actor_steps) + for s in range(loop_steps): print(f"\n training loop step {s}") #bookkeeping and testing if (s + 1) % (10000 // (num_agents * FLAGS.actor_steps)) == 0: @@ -240,7 +248,10 @@ def train( if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: test_episodes.policy_test(1, optimizer.target, game, render=False) - + if FLAGS.decaying_lr_and_clip_param: + alpha = 1. - s/loop_steps + else: + alpha = 1. # send the up-to-date policy model and current step to inference thread step = s*num_agents policy_q.put((optimizer, step)) @@ -283,12 +294,14 @@ def train( lambda x: onp.reshape( x, (FLAGS.num_agents * FLAGS.actor_steps, ) + x.shape[2:]), trn_data)) print(f"Step {s}: rewards variance {rewards.var()}") + lr = FLAGS.learning_rate * alpha + clip_param = FLAGS.clip_param * alpha for e in range(FLAGS.num_epochs): #possibly compile this loop inside a jit shapes = list(map(lambda x : x.shape, trn_data)) permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) trn_data = tuple(map(lambda x: x[permutation], trn_data)) optimizer, loss, last_iter_grad_norm = train_step(optimizer, trn_data, - FLAGS.clip_param, FLAGS.vf_coeff, FLAGS.entropy_coeff) + clip_param, FLAGS.vf_coeff, FLAGS.entropy_coeff, lr) print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") #end of PPO training @@ -300,7 +313,7 @@ def main(argv): num_actions = env_utils.get_num_actions(game) print(f"Playing {game} with {num_actions} actions") num_agents = FLAGS.num_agents - total_frames = 10000000 + total_frames = 40000000 train_device = jax.devices()[0] inference_device = jax.devices()[1] key = jax.random.PRNGKey(0) From f398660136ffbf796f8a1a6e191605887b259429 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 22 Sep 2020 11:57:43 +0000 Subject: [PATCH 30/48] Models: jnp.maximum->nn.relu and use dtype everywhere --- examples/ppo/models.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/ppo/models.py b/examples/ppo/models.py index 7c6f92e624..8708d9dd87 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -9,33 +9,29 @@ class ActorCritic(flax.nn.Module): def apply(self, x, num_outputs): """Define the convolutional network architecture. - - Architecture originates from "Human-level control through deep reinforcement + + Architecture originates from "Human-level control through deep reinforcement learning.", Nature 518, no. 7540 (2015): 529-533. Note that this is different than the one from "Playing atari with deep reinforcement learning." arxiv.org/abs/1312.5602 (2013) """ - x = x.astype(jnp.float32) / 255. 
dtype = jnp.float32 + x = x.astype(dtype) / 255. x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1', dtype=dtype) - # x = nn.relu(x) - x = jnp.maximum(0, x) + x = nn.relu(x) x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2', dtype=dtype) - # x = nn.relu(x) - x = jnp.maximum(0, x) + x = nn.relu(x) x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3', dtype=dtype) - # x = nn.relu(x) - x = jnp.maximum(0, x) + x = nn.relu(x) x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(x, features=512, name='hidden', dtype=dtype) - # x = nn.relu(x) - x = jnp.maximum(0, x) + x = nn.relu(x) # network used to both estimate policy (logits) and expected state value # see github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype) From 19dbbc27069d0fe0d91d1bb17b3a2bfb9006993c Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 22 Sep 2020 12:35:19 +0000 Subject: [PATCH 31/48] Append and then reverse instead of pushing in front in GAE estimation --- examples/ppo/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index f654bd39ae..19f35288f0 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -96,7 +96,8 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): # masks[t] to ensure that values before and after a terminal state # are independent of each other gae = delta + discount * gae_param * terminal_masks[t] * gae - advantages = [gae] + advantages + advantages.append(gae) + advantages = advantages[::-1] return jnp.array(advantages) @jax.jit From 518a7f61b86e6015d09c3314077b8b75003be205 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 23 Sep 2020 13:18:56 +0000 Subject: [PATCH 32/48] Unit & policy test improvements --- examples/ppo/main.py | 2 +- examples/ppo/test_episodes.py | 19 +++------- examples/ppo/unit_tests.py | 68 +++++++++++++++++++++++++---------- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 19f35288f0..b5c88cfbcc 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -247,7 +247,7 @@ def train( f"time elapsed {time.time() - t1}") t1 = time.time() if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: - test_episodes.policy_test(1, optimizer.target, game, render=False) + test_episodes.policy_test(1, optimizer.target, game) if FLAGS.decaying_lr_and_clip_param: alpha = 1. - s/loop_steps diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 4afa11851c..4c8aa76e27 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -1,8 +1,6 @@ """Test policy by playing a full Atari game.""" -import time import itertools -import gym import flax import numpy as onp @@ -11,27 +9,22 @@ import agent def policy_test( - n_episodes : int, - model : flax.nn.base.Model, - game : str, - render : bool = False): + n_episodes: int, + model: flax.nn.base.Model, + game: str): """Perform a test of the policy in Atari environment. 
Args: n_episodes: number of full Atari episodes to test on model: the actor-critic model being tested game: defines the Atari game to test on - render: whether to render the test environment Returns: None """ test_env = env_utils.create_env(game, clip_rewards=False) - if render: - test_env = gym.wrappers.Monitor( - test_env, "./rendered/" + "ddqn_pong_recording", force=True) all_probabilities = [] - for e in range(n_episodes): + for _ in range(n_episodes): obs = test_env.reset() state = remote.get_state(obs) total_reward = 0.0 @@ -43,9 +36,6 @@ def policy_test( action = onp.random.choice(probs.shape[1], p=probabilities) obs, reward, done, _ = test_env.step(action) total_reward += reward - if render: - test_env.render() - time.sleep(0.01) if not done: next_state = remote.get_state(obs) else: @@ -53,7 +43,6 @@ def policy_test( state = next_state if done: all_probabilities = onp.stack(all_probabilities, axis=0) - print(f"all_probabilities shape {all_probabilities.shape}") vars = onp.var(all_probabilities, axis=0) print(f"------> TEST FINISHED: reward {total_reward} in {t} steps") print(f"Variance of probabilities across encuntered states {vars}") diff --git a/examples/ppo/unit_tests.py b/examples/ppo/unit_tests.py index 4c90d6a674..6bf94d3341 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/unit_tests.py @@ -9,9 +9,9 @@ import env_utils import models -#test GAE +# test GAE class TestGAE(absltest.TestCase): - def test_gae_random(self): + def test_gae_shape_on_random(self): # create random data, simulating 4 parallel envs and 20 time_steps envs, steps = 10, 100 rewards = onp.random.choice([-1., 0., 1.], size=(steps, envs), @@ -20,37 +20,40 @@ def test_gae_random(self): values = onp.random.random(size=(steps + 1, envs)) discount = 0.99 gae_param = 0.95 - adv = main.gae_advantages(rewards, terminal_masks, values, discount, + adv = main.gae_advantages(rewards, terminal_masks, values, discount, gae_param) self.assertEqual(adv.shape, (steps, envs)) - # test the property A_{t} = \delta_t + \gamma*\lambda*A_{t+1} - # for each agent separately - for e in range(envs): - for t in range(steps-1): - delta = rewards[t, e] + discount * values[t+1, e] - values[t, e] - lhs = adv[t, e] - rhs = delta + discount * gae_param * adv[t+1, e] - onp_testing.assert_almost_equal(lhs, rhs) - -#test environment and preprocessing + def test_gae_hardcoded(self): + #test on small example that can be verified by hand + rewards = onp.array([[1., 0.], [0., 0.], [-1., 1.]]) + #one of the two episodes terminated in the middle + terminal_masks = onp.array([[1., 1.], [0., 1.], [1., 1.]]) + values = onp.array([[1., 1.], [1., 1.], [1., 1.], [1., 1.]]) + discount = 0.5 + gae_param = 0.25 + correct_gae = onp.array([[0.375, -0.5546875], [-1., -0.4375], [-1.5, 0.5]]) + actual_gae = main.gae_advantages(rewards, terminal_masks, values, discount, + gae_param) + onp_testing.assert_allclose(actual_gae, correct_gae) +# test environment and preprocessing class TestEnvironmentPreprocessing(absltest.TestCase): def choose_random_game(self): games = ['BeamRider', 'Breakout', 'Pong', - 'Qbert', 'Seaquest', 'SpaceInvaders'] + 'Qbert', 'Seaquest', 'SpaceInvaders'] ind = onp.random.choice(len(games)) return games[ind] + "NoFrameskip-v4" def test_creation(self): frame_shape = (84, 84, 4) game = self.choose_random_game() - env = env_utils.create_env(game) + env = env_utils.create_env(game, clip_rewards=True) obs = env.reset() self.assertTrue(obs.shape == frame_shape) def test_step(self): frame_shape = (84, 84, 4) game = self.choose_random_game() 
- env = env_utils.create_env(game) + env = env_utils.create_env(game, clip_rewards=False) obs = env.reset() actions = [1, 2, 3, 0] for a in actions: @@ -60,10 +63,10 @@ def test_step(self): self.assertTrue(isinstance(done, bool)) self.assertTrue(isinstance(info, dict)) -#test the model (creation and forward pass) +# test the model (creation and forward pass) class TestModel(absltest.TestCase): def choose_random_outputs(self): - return onp.random.choice([4,5,6,7,8,9]) + return onp.random.choice([4, 5, 6, 7, 8, 9]) def test_model(self): key = jax.random.PRNGKey(0) @@ -79,8 +82,35 @@ def test_model(self): self.assertTrue(values.shape == (test_batch_size, 1)) sum_probs = onp.sum(onp.exp(log_probs), axis=1) self.assertTrue(sum_probs.shape == (test_batch_size, )) - onp_testing.assert_almost_equal(sum_probs, onp.ones((test_batch_size, ))) + onp_testing.assert_allclose(sum_probs, onp.ones((test_batch_size, )), + atol=1e-6) + +# test one optimization step +class TestOptimizationStep(absltest.TestCase): + def generate_random_data(self, num_actions): + data_len = 256 # equal to one default-sized batch + state_shape = (84, 84, 4) + states = onp.random.randint(0, 255, size=((data_len, ) + state_shape)) + actions = onp.random.choice(num_actions, size=data_len) + old_log_probs = onp.random.random(size=data_len) + returns = onp.random.random(size=data_len) + advantages = onp.random.random(size=data_len) + return states, actions, old_log_probs, returns, advantages + def test_optimization_step(self): + num_outputs = 4 + trn_data = self.generate_random_data(num_actions=num_outputs) + clip_param = 0.1 + vf_coeff = 0.5 + entropy_coeff = 0.01 + lr = 2.5e-4 + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + model = models.create_model(subkey, num_outputs) + optimizer = models.create_optimizer(model, learning_rate=lr) + optimizer, _, _ = main.train_step( + optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr) + self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) if __name__ == '__main__': absltest.main() From 8ef4493728cbfa8e8874ca489ded7d39bd6198c8 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 23 Sep 2020 15:24:43 +0000 Subject: [PATCH 33/48] Fix conflict in setup.py --- setup.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index babbdcf69d..4a8a1e3762 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,6 @@ from setuptools import find_packages from setuptools import setup -version = "0.2.0" - here = os.path.abspath(os.path.dirname(__file__)) try: README = open(os.path.join(here, "README.md"), encoding='utf-8').read() @@ -35,9 +33,8 @@ ] tests_require = [ - "gym", "jaxlib", - "opencv-python", + "ml-collections", "pytest", "pytest-cov", "pytest-xdist==1.34.0", # upgrading to 2.0 broke tests, need to investigate @@ -46,9 +43,14 @@ "tensorflow_datasets", ] +__version__ = None + +with open('flax/version.py') as f: + exec(f.read(), globals()) + setup( name="flax", - version=version, + version=__version__, description="Flax: A neural network library for JAX designed for flexibility", long_description="\n\n".join([README]), long_description_content_type='text/markdown', From e846aeff353ac49419b8f3e92f12ce2ff0d6e510 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 23 Sep 2020 15:32:36 +0000 Subject: [PATCH 34/48] Add required packages to test requirements --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 4a8a1e3762..43f387c877 100644 --- a/setup.py +++ 
b/setup.py @@ -33,8 +33,10 @@ ] tests_require = [ + "gym", "jaxlib", "ml-collections", + "opencv-python", "pytest", "pytest-cov", "pytest-xdist==1.34.0", # upgrading to 2.0 broke tests, need to investigate From 7b02ec079fce275d92302adf618775bea7ba564d Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 23 Sep 2020 17:04:53 +0000 Subject: [PATCH 35/48] Cleanup of main.py incl. variable rename --- examples/ppo/main.py | 77 +++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 33 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index b5c88cfbcc..525b8fc386 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -1,15 +1,15 @@ -import jax -import jax.random -import jax.numpy as jnp -import numpy as onp -import flax import time import functools import queue +from typing import Tuple, List +import threading from absl import flags from absl import app -import threading -from typing import Tuple, List +import jax +import jax.random +import jax.numpy as jnp +import numpy as onp +import flax import models import agent @@ -85,6 +85,17 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses key observation that A_{t} = delta_t + gamma*lambda*A_{t+1}. + + Args: + rewards: array shaped (actor_steps, num_agents), rewards from the game + terminal_masks: array shaped (actor_steps, num_agents), zeros for terminal + and ones for non-terminal states + values: array shaped (actor_steps, num_agents), values estimated by critic + discount: RL discount usually denoted with gamma + gae_param: GAE parameter usually denoted with lambda + + Returns: + advantages: calculated advantages shaped (actor_steps, num_agents) """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") @@ -101,7 +112,13 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): return jnp.array(advantages) @jax.jit -def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr): +def train_step( + optimizer: flax.optim.base.Optimizer, + trajectories: Tuple[onp.array, onp.array, onp.array, onp.array, onp.array], + clip_param: float, + vf_coeff: float, + entropy_coeff: float, + lr: float): """Compilable train step. Runs an entire epoch of training (i.e. 
the loop over @@ -109,7 +126,7 @@ def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr): Args: optimizer: optimizer for the actor-critic model - trn_data: Tuple of the following five elements forming the experience: + trajectories: Tuple of the following five elements forming the experience: states: shape (steps_per_agent*num_agents, 84, 84, 4) actions: shape (steps_per_agent*num_agents, 84, 84, 4) old_log_probs: shape (steps_per_agent*num_agents, ) @@ -128,7 +145,6 @@ def train_step(optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr): """ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch - shapes = list(map(lambda x : x.shape, minibatch)) log_probs, values = model(states) values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) probs = jnp.exp(log_probs) @@ -147,11 +163,11 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy batch_size = FLAGS.batch_size - iterations = trn_data[0].shape[0] // batch_size - trn_data = jax.tree_map( - lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trn_data) + iterations = trajectories[0].shape[0] // batch_size + trajectories = jax.tree_map( + lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) loss = 0. - for batch in zip(*trn_data): + for batch in zip(*trajectories): grad_fn = jax.value_and_grad(loss_fn) l, grad = grad_fn(optimizer.target, batch, clip_param, vf_coeff, entropy_coeff) @@ -170,7 +186,7 @@ def thread_inference( Runs `steps_per_actor` time steps of the game for each of the `simulators`. """ - while(True): + while True: optimizer, step = policy_q.get() all_experience = [] for _ in range(steps_per_actor + 1): # +1 to get one more value @@ -187,14 +203,9 @@ def thread_inference( log_probs, values = jax.device_get((log_probs, values)) probs = onp.exp(onp.array(log_probs)) - # print("probs after onp conversion", probs) for i, sim in enumerate(simulators): - # probs[i] should sum up to 1, but there are float round errors - # if using jnp.array directly, it required division by probs[i].sum() - # better solutions can be thought of - # issue might be a result of the network using jnp.int 32, , not 64 - probabilities = probs[i] # / probs[i].sum() + probabilities = probs[i] action = onp.random.choice(probs.shape[1], p=probabilities) #in principle, one could avoid sending value and log prob back and forth sim.conn.send((action, values[i, 0], log_probs[i][action])) @@ -258,9 +269,8 @@ def train( policy_q.put((optimizer, step)) # perform PPO training - # experience is a list of list of tuples, here we preprocess this data to - # get required input for GAE and then for training - # initial version, needs improvement in terms of speed & readability + # all_experience is a list of list of namedtuples; preprocess this data to + # get required input (trajectories) for GAE and later for training all_experiences = experience_q.get() if s >= 0: #avoid training when there's no data yet obs_shape = (84, 84, 4) @@ -290,19 +300,21 @@ def train( FLAGS.gamma, FLAGS.lambda_) returns = advantages + values[:-1, :] # after preprocessing, concatenate data from all agents - trn_data = (states, actions, log_probs, returns, advantages) - trn_data = tuple(map( + trajectories = (states, actions, log_probs, returns, advantages) + trajectories = tuple(map( lambda x: onp.reshape( - x, (FLAGS.num_agents * FLAGS.actor_steps, ) + x.shape[2:]), 
trn_data)) + x, (FLAGS.num_agents * FLAGS.actor_steps,) + x.shape[2:]), + trajectories)) print(f"Step {s}: rewards variance {rewards.var()}") lr = FLAGS.learning_rate * alpha clip_param = FLAGS.clip_param * alpha - for e in range(FLAGS.num_epochs): #possibly compile this loop inside a jit - shapes = list(map(lambda x : x.shape, trn_data)) + for e in range(FLAGS.num_epochs): + shapes = list(map(lambda x: x.shape, trajectories)) permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) - trn_data = tuple(map(lambda x: x[permutation], trn_data)) - optimizer, loss, last_iter_grad_norm = train_step(optimizer, trn_data, - clip_param, FLAGS.vf_coeff, FLAGS.entropy_coeff, lr) + trajectories = tuple(map(lambda x: x[permutation], trajectories)) + optimizer, loss, last_iter_grad_norm = train_step( + optimizer, trajectories, clip_param, FLAGS.vf_coeff, + FLAGS.entropy_coeff, lr) print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") #end of PPO training @@ -322,7 +334,6 @@ def main(argv): model = models.create_model(subkey, num_outputs=num_actions) optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model - # jax.device_put(optimizer.target, device=train_device) train(optimizer, game, total_frames, num_agents, train_device, inference_device) From 50b2b792454832469eeeb38189352a5f53071f65 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 24 Sep 2020 16:35:04 +0000 Subject: [PATCH 36/48] Streamline training: use one thread, divide code into smaller chunks --- examples/ppo/main.py | 223 ++++++++++++++++++----------------------- examples/ppo/remote.py | 6 +- 2 files changed, 100 insertions(+), 129 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 525b8fc386..5179824ff7 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -1,8 +1,5 @@ -import time import functools -import queue from typing import Tuple, List -import threading from absl import flags from absl import app import jax @@ -101,11 +98,11 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") advantages, gae = [], 0. for t in reversed(range(len(rewards))): - # masks to set next state value to 0 for terminal states + # Masks used to set next state value to 0 for terminal states. value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] delta = rewards[t] + value_diff - # masks[t] to ensure that values before and after a terminal state - # are independent of each other + # Masks[t] used to ensure that values before and after a terminal state + # are independent of each other. gae = delta + discount * gae_param * terminal_masks[t] * gae advantages.append(gae) advantages = advantages[::-1] @@ -146,19 +143,17 @@ def train_step( def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch log_probs, values = model(states) - values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) + values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) probs = jnp.exp(log_probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) ratios = jnp.exp(log_probs_act_taken - old_log_probs) - # adv. normalization (following the OpenAI baselines) + # Advantage normalization (following the OpenAI baselines). 
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) PG_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, 1. + clip_param) - assert(PG_loss.shape == clipped_loss.shape) PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0) - assert(values.shape == returns.shape) value_loss = jnp.mean(jnp.square(returns - values), axis=0) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy @@ -176,87 +171,105 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): grad_norm = sum(jnp.square(g).sum() for g in jax.tree_leaves(grad)) return optimizer, loss, grad_norm - -def thread_inference( - policy_q: queue.Queue, - experience_q: queue.Queue, +def get_experience( + model: flax.optim.base.Optimizer, simulators: List[remote.RemoteSimulator], steps_per_actor: int): - """Worker function for a separate inference thread. + """Collect experience from agents. Runs `steps_per_actor` time steps of the game for each of the `simulators`. """ - while True: - optimizer, step = policy_q.get() - all_experience = [] - for _ in range(steps_per_actor + 1): # +1 to get one more value - # needed for GAE - states = [] - for sim in simulators: - state = sim.conn.recv() - states.append(state) - states = onp.concatenate(states, axis=0) - - # perform inference - # policy_optimizer, step = policy_q.get() - log_probs, values = agent.policy_action(optimizer.target, states) - log_probs, values = jax.device_get((log_probs, values)) - - probs = onp.exp(onp.array(log_probs)) - - for i, sim in enumerate(simulators): - probabilities = probs[i] - action = onp.random.choice(probs.shape[1], p=probabilities) - #in principle, one could avoid sending value and log prob back and forth - sim.conn.send((action, values[i, 0], log_probs[i][action])) - - # get experience from simulators - experiences = [] - for sim in simulators: - sample = sim.conn.recv() - experiences.append(sample) - all_experience.append(experiences) - - experience_q.put(all_experience) + all_experience = [] + # Range up to steps_per_actor + 1 to get one more value needed for GAE. + for _ in range(steps_per_actor + 1): + states = [] + for sim in simulators: + state = sim.conn.recv() + states.append(state) + states = onp.concatenate(states, axis=0) + log_probs, values = agent.policy_action(model, states) + log_probs, values = jax.device_get((log_probs, values)) + probs = onp.exp(onp.array(log_probs)) + for i, sim in enumerate(simulators): + probabilities = probs[i] + action = onp.random.choice(probs.shape[1], p=probabilities) + # In principle, one could avoid sending value and log prob back and forth. + sim.conn.send((action, values[i, 0], log_probs[i][action])) + experiences = [] + for sim in simulators: + sample = sim.conn.recv() + experiences.append(sample) + all_experience.append(experiences) + return all_experience + +def process_experience( + experience: List[List[remote.ExpTuple]], + actor_steps: int, + num_agents: int): + """Process experience for training, including advantage estimation. 
+ Args: + experience: collected from agents in the form of nested lists/namedtuple + actor_steps: number of steps each agent has completed + num_agents: number of agents that collected experience + + Returns: + trajectories: trajectories readily accessible for `train_step()` function + """ + obs_shape = (84, 84, 4) + exp_dims = (actor_steps, num_agents) + values_dims = (actor_steps + 1, num_agents) + states = onp.zeros(exp_dims + obs_shape, dtype=onp.float32) + actions = onp.zeros(exp_dims, dtype=onp.int32) + rewards = onp.zeros(exp_dims, dtype=onp.float32) + values = onp.zeros(values_dims, dtype=onp.float32) + log_probs = onp.zeros(exp_dims, dtype=onp.float32) + dones = onp.zeros(exp_dims, dtype=onp.float32) + + for t in range(len(experience) - 1): # experience[-1] only for next_values + for agent_id, exp_agent in enumerate(experience[t]): + states[t, agent_id, ...] = exp_agent.state + actions[t, agent_id] = exp_agent.action + rewards[t, agent_id] = exp_agent.reward + values[t, agent_id] = exp_agent.value + log_probs[t, agent_id] = exp_agent.log_prob + # Dones need to be 0 for terminal states. + dones[t, agent_id] = float(not exp_agent.done) + for a in range(num_agents): + values[-1, a] = experience[-1][a].value + advantages = gae_advantages(rewards, dones, values, + FLAGS.gamma, FLAGS.lambda_) + returns = advantages + values[:-1, :] + # After preprocessing, concatenate data from all agents. + trajectories = (states, actions, log_probs, returns, advantages) + trajectories = tuple(map( + lambda x: onp.reshape( + x, (FLAGS.num_agents * FLAGS.actor_steps,) + x.shape[2:]), + trajectories)) + return trajectories def train( optimizer: flax.optim.base.Optimizer, game: str, steps_total: int, - num_agents: int, - train_device, - inference_device): + num_agents: int): """Main training loop. Args: optimizer: optimizer for the actor-critic model - game: string specifying the Atari game from Gym package + game: string specifying the Atari game from gym package steps total: total number of frames (env steps) to train on num_agents: number of separate processes with agents running the envs - train_device : device used for training - inference_device : device used for inference Returns: - None + optimizer: the trained optimizer """ simulators = [remote.RemoteSimulator(game) for i in range(num_agents)] - policy_q = queue.Queue(maxsize=1) - experience_q = queue.Queue(maxsize=1) - inference_thread = threading.Thread( - target=thread_inference, - args=(policy_q, experience_q, simulators, FLAGS.actor_steps), - daemon=True) - inference_thread.start() - t1 = time.time() loop_steps = steps_total // (num_agents * FLAGS.actor_steps) for s in range(loop_steps): + # Bookkeeping and testing. print(f"\n training loop step {s}") - #bookkeeping and testing - if (s + 1) % (10000 // (num_agents * FLAGS.actor_steps)) == 0: - print(f" Frames processed {s * num_agents * FLAGS.actor_steps}, " + - f"time elapsed {time.time() - t1}") - t1 = time.time() + if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: test_episodes.policy_test(1, optimizer.target, game) @@ -264,61 +277,22 @@ def train( alpha = 1. - s/loop_steps else: alpha = 1. 
- # send the up-to-date policy model and current step to inference thread - step = s*num_agents - policy_q.put((optimizer, step)) - - # perform PPO training - # all_experience is a list of list of namedtuples; preprocess this data to - # get required input (trajectories) for GAE and later for training - all_experiences = experience_q.get() - if s >= 0: #avoid training when there's no data yet - obs_shape = (84, 84, 4) - exp_dims = (FLAGS.actor_steps, FLAGS.num_agents) - values_dims = (FLAGS.actor_steps + 1, FLAGS.num_agents) - states = onp.zeros(exp_dims + obs_shape, dtype=onp.float32) - actions = onp.zeros(exp_dims, dtype=onp.int32) - rewards = onp.zeros(exp_dims, dtype=onp.float32) - values = onp.zeros(values_dims, dtype=onp.float32) - log_probs = onp.zeros(exp_dims, dtype=onp.float32) - dones = onp.zeros(exp_dims, dtype=onp.float32) - - assert(len(all_experiences[0]) == FLAGS.num_agents) - for t in range(len(all_experiences) - 1): #last only for next_values - for agent_id, exp_agent in enumerate(all_experiences[t]): - states[t, agent_id, ...] = exp_agent.state - actions[t, agent_id] = exp_agent.action - rewards[t, agent_id] = exp_agent.reward - values[t, agent_id] = exp_agent.value - log_probs[t, agent_id] = exp_agent.log_prob - # dones need to be 0 for terminal states - dones[t, agent_id] = float(not exp_agent.done) - for a in range(num_agents): - values[-1, a] = all_experiences[-1][a].value - # calculate advantages w. GAE - advantages = gae_advantages(rewards, dones, values, - FLAGS.gamma, FLAGS.lambda_) - returns = advantages + values[:-1, :] - # after preprocessing, concatenate data from all agents - trajectories = (states, actions, log_probs, returns, advantages) - trajectories = tuple(map( - lambda x: onp.reshape( - x, (FLAGS.num_agents * FLAGS.actor_steps,) + x.shape[2:]), - trajectories)) - print(f"Step {s}: rewards variance {rewards.var()}") - lr = FLAGS.learning_rate * alpha - clip_param = FLAGS.clip_param * alpha - for e in range(FLAGS.num_epochs): - shapes = list(map(lambda x: x.shape, trajectories)) - permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) - trajectories = tuple(map(lambda x: x[permutation], trajectories)) - optimizer, loss, last_iter_grad_norm = train_step( - optimizer, trajectories, clip_param, FLAGS.vf_coeff, - FLAGS.entropy_coeff, lr) - print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") - #end of PPO training - - return None + + # Core training code. 
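# Illustrative sketch (not part of the patch; default flag values assumed:
# 8 agents, 128 actor steps, batch size 256) of the shuffling scheme in the
# training loop: the flattened trajectories are permuted once per epoch and
# train_step then slices them into minibatches with a plain reshape.
import numpy as onp

num_agents, actor_steps, batch_size = 8, 128, 256
data_len = num_agents * actor_steps       # 1024 flattened samples
fake_field = onp.arange(data_len)         # stands in for one trajectory field

permutation = onp.random.permutation(data_len)
shuffled = fake_field[permutation]

iterations = data_len // batch_size       # 4 minibatches per epoch
minibatches = shuffled.reshape((iterations, batch_size))
assert minibatches.shape == (4, 256)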
+ all_experiences = get_experience(optimizer.target, simulators, + FLAGS.actor_steps) + trajectories = process_experience(all_experiences, FLAGS.actor_steps, + FLAGS.num_agents) + lr = FLAGS.learning_rate * alpha + clip_param = FLAGS.clip_param * alpha + for e in range(FLAGS.num_epochs): + permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) + trajectories = tuple(map(lambda x: x[permutation], trajectories)) + optimizer, loss, last_iter_grad_norm = train_step( + optimizer, trajectories, clip_param, FLAGS.vf_coeff, + FLAGS.entropy_coeff, lr) + print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") + return optimizer def main(argv): game = "Pong" @@ -327,15 +301,12 @@ def main(argv): print(f"Playing {game} with {num_actions} actions") num_agents = FLAGS.num_agents total_frames = 40000000 - train_device = jax.devices()[0] - inference_device = jax.devices()[1] key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = models.create_model(subkey, num_outputs=num_actions) optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model - train(optimizer, game, total_frames, num_agents, train_device, - inference_device) + optimizer = train(optimizer, game, total_frames, num_agents) if __name__ == '__main__': app.run(main) diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py index b6494e87b2..8625f05f4a 100644 --- a/examples/ppo/remote.py +++ b/examples/ppo/remote.py @@ -6,8 +6,8 @@ import env_utils -exp_tuple = collections.namedtuple( - 'exp_tuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) +ExpTuple = collections.namedtuple( + 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) class RemoteSimulator: """Wrap functionality for an agent emulating Atari in a separate process. @@ -39,7 +39,7 @@ def rcv_action_send_exp(conn, game): action, value, log_prob = conn.recv() obs, reward, done, _ = env.step(action) next_state = get_state(obs) if not done else None - experience = exp_tuple(state, action, reward, value, log_prob, done) + experience = ExpTuple(state, action, reward, value, log_prob, done) conn.send(experience) if done: break From df3daa193a319c29eaf9299708ede11d37c0ffe6 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 24 Sep 2020 17:23:31 +0000 Subject: [PATCH 37/48] Avoid using global variables --- examples/ppo/main.py | 50 +++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/examples/ppo/main.py b/examples/ppo/main.py index 5179824ff7..a84ce1050f 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/main.py @@ -95,7 +95,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): advantages: calculated advantages shaped (actor_steps, num_agents) """ assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " - "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") + "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") advantages, gae = [], 0. for t in reversed(range(len(rewards))): # Masks used to set next state value to 0 for terminal states. 
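# A small plain-NumPy reference for the GAE recursion above (illustrative
# only): A_t = delta_t + gamma * lambda * mask_t * A_{t+1}. The toy inputs
# and the expected output are the same ones used in test_gae_hardcoded.
import numpy as onp

rewards = onp.array([[1., 0.], [0., 0.], [-1., 1.]])
masks = onp.array([[1., 1.], [0., 1.], [1., 1.]])  # 0 marks a terminal state
values = onp.ones((4, 2))
gamma, lam = 0.5, 0.25

advantages, gae = [], onp.zeros(2)
for t in reversed(range(rewards.shape[0])):
  delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
  gae = delta + gamma * lam * masks[t] * gae
  advantages.append(gae)
advantages = onp.stack(advantages[::-1])
# advantages is now [[0.375, -0.5546875], [-1., -0.4375], [-1.5, 0.5]],
# matching correct_gae in the unit test.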
@@ -108,14 +108,15 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): advantages = advantages[::-1] return jnp.array(advantages) -@jax.jit +@functools.partial(jax.jit, static_argnums=(6)) def train_step( optimizer: flax.optim.base.Optimizer, trajectories: Tuple[onp.array, onp.array, onp.array, onp.array, onp.array], clip_param: float, vf_coeff: float, entropy_coeff: float, - lr: float): + lr: float, + batch_size: int): """Compilable train step. Runs an entire epoch of training (i.e. the loop over @@ -134,6 +135,7 @@ def train_step( entropy_coeff: weighs entropy bonus in the total loss lr: learning rate, varies between optimization steps if decaying_lr_and_clip_param is set to true + batch_size: the minibatch size, static argument Returns: optimizer: new optimizer after the parameters update @@ -157,7 +159,6 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): value_loss = jnp.mean(jnp.square(returns - values), axis=0) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy - batch_size = FLAGS.batch_size iterations = trajectories[0].shape[0] // batch_size trajectories = jax.tree_map( lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) @@ -205,13 +206,17 @@ def get_experience( def process_experience( experience: List[List[remote.ExpTuple]], actor_steps: int, - num_agents: int): + num_agents: int, + gamma: float, + lambda_: float): """Process experience for training, including advantage estimation. Args: experience: collected from agents in the form of nested lists/namedtuple actor_steps: number of steps each agent has completed num_agents: number of agents that collected experience + gamma: dicount parameter + lambda_: GAE parameter Returns: trajectories: trajectories readily accessible for `train_step()` function @@ -237,14 +242,13 @@ def process_experience( dones[t, agent_id] = float(not exp_agent.done) for a in range(num_agents): values[-1, a] = experience[-1][a].value - advantages = gae_advantages(rewards, dones, values, - FLAGS.gamma, FLAGS.lambda_) + advantages = gae_advantages(rewards, dones, values, gamma, lambda_) returns = advantages + values[:-1, :] # After preprocessing, concatenate data from all agents. trajectories = (states, actions, log_probs, returns, advantages) trajectories = tuple(map( lambda x: onp.reshape( - x, (FLAGS.num_agents * FLAGS.actor_steps,) + x.shape[2:]), + x, (num_agents * actor_steps,) + x.shape[2:]), trajectories)) return trajectories @@ -252,7 +256,8 @@ def train( optimizer: flax.optim.base.Optimizer, game: str, steps_total: int, - num_agents: int): + num_agents: int, + flags_: flags._flagvalues.FlagValues): """Main training loop. Args: @@ -265,32 +270,33 @@ def train( optimizer: the trained optimizer """ simulators = [remote.RemoteSimulator(game) for i in range(num_agents)] - loop_steps = steps_total // (num_agents * FLAGS.actor_steps) + loop_steps = steps_total // (num_agents * flags_.actor_steps) for s in range(loop_steps): # Bookkeeping and testing. print(f"\n training loop step {s}") - if (s + 1) % (20000 // (num_agents * FLAGS.actor_steps)) == 0: + if (s + 1) % (20000 // (num_agents * flags_.actor_steps)) == 0: test_episodes.policy_test(1, optimizer.target, game) - if FLAGS.decaying_lr_and_clip_param: + if flags_.decaying_lr_and_clip_param: alpha = 1. - s/loop_steps else: alpha = 1. # Core training code. 
all_experiences = get_experience(optimizer.target, simulators, - FLAGS.actor_steps) - trajectories = process_experience(all_experiences, FLAGS.actor_steps, - FLAGS.num_agents) - lr = FLAGS.learning_rate * alpha - clip_param = FLAGS.clip_param * alpha - for e in range(FLAGS.num_epochs): - permutation = onp.random.permutation(num_agents * FLAGS.actor_steps) + flags_.actor_steps) + trajectories = process_experience( + all_experiences, flags_.actor_steps, flags_.num_agents, flags_.gamma, + flags_.lambda_) + lr = flags_.learning_rate * alpha + clip_param = flags_.clip_param * alpha + for e in range(flags_.num_epochs): + permutation = onp.random.permutation(num_agents * flags_.actor_steps) trajectories = tuple(map(lambda x: x[permutation], trajectories)) optimizer, loss, last_iter_grad_norm = train_step( - optimizer, trajectories, clip_param, FLAGS.vf_coeff, - FLAGS.entropy_coeff, lr) + optimizer, trajectories, clip_param, flags_.vf_coeff, + flags_.entropy_coeff, lr, flags_.batch_size) print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") return optimizer @@ -306,7 +312,7 @@ def main(argv): model = models.create_model(subkey, num_outputs=num_actions) optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model - optimizer = train(optimizer, game, total_frames, num_agents) + optimizer = train(optimizer, game, total_frames, num_agents, FLAGS) if __name__ == '__main__': app.run(main) From 7e036ae59eafc4b0062c204921c9b6d7b2bc0597 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Thu, 24 Sep 2020 17:42:10 +0000 Subject: [PATCH 38/48] Adhere to file naming standard --- examples/ppo/{main.py => ppo_lib.py} | 85 +----------------- .../ppo/{unit_tests.py => ppo_lib_test.py} | 13 +-- examples/ppo/ppo_main.py | 86 +++++++++++++++++++ setup.py | 1 + 4 files changed, 97 insertions(+), 88 deletions(-) rename examples/ppo/{main.py => ppo_lib.py} (82%) rename examples/ppo/{unit_tests.py => ppo_lib_test.py} (93%) create mode 100644 examples/ppo/ppo_main.py diff --git a/examples/ppo/main.py b/examples/ppo/ppo_lib.py similarity index 82% rename from examples/ppo/main.py rename to examples/ppo/ppo_lib.py index a84ce1050f..3a275993fb 100644 --- a/examples/ppo/main.py +++ b/examples/ppo/ppo_lib.py @@ -1,79 +1,17 @@ +"""Library file which executes the PPO training""" + import functools from typing import Tuple, List from absl import flags -from absl import app import jax import jax.random import jax.numpy as jnp import numpy as onp import flax -import models import agent import remote import test_episodes -import env_utils - -FLAGS = flags.FLAGS - -# default hyperparameters taken from PPO paper and openAI baselines 2 -# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py - -flags.DEFINE_float( - 'learning_rate', default=2.5e-4, - help=('The learning rate for the Adam optimizer.') -) - -flags.DEFINE_integer( - 'batch_size', default=256, - help=('Batch size for training.') -) - -flags.DEFINE_integer( - 'num_agents', default=8, - help=('Number of agents playing in parallel.') -) - -flags.DEFINE_integer( - 'actor_steps', default=128, - help=('Batch size for training.') -) - -flags.DEFINE_integer( - 'num_epochs', default=3, - help=('Number of epochs per each unroll of the policy.') -) - -flags.DEFINE_float( - 'gamma', default=0.99, - help=('Discount parameter.') -) - -flags.DEFINE_float( - 'lambda_', default=0.95, - help=('Generalized Advantage Estimation parameter.') -) - -flags.DEFINE_float( - 'clip_param', default=0.1, - help=('The PPO clipping parameter 
used to clamp ratios in loss function.') -) - -flags.DEFINE_float( - 'vf_coeff', default=0.5, - help=('Weighs value function loss in the total loss.') -) - -flags.DEFINE_float( - 'entropy_coeff', default=0.01, - help=('Weighs entropy bonus in the total loss.') -) - -flags.DEFINE_boolean( - 'decaying_lr_and_clip_param', default=True, - help=(('Linearly decay learning rate and clipping parameter to zero during ' - 'the training.')) -) @functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @jax.jit @@ -298,21 +236,4 @@ def train( optimizer, trajectories, clip_param, flags_.vf_coeff, flags_.entropy_coeff, lr, flags_.batch_size) print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") - return optimizer - -def main(argv): - game = "Pong" - game += "NoFrameskip-v4" - num_actions = env_utils.get_num_actions(game) - print(f"Playing {game} with {num_actions} actions") - num_agents = FLAGS.num_agents - total_frames = 40000000 - key = jax.random.PRNGKey(0) - key, subkey = jax.random.split(key) - model = models.create_model(subkey, num_outputs=num_actions) - optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) - del model - optimizer = train(optimizer, game, total_frames, num_agents, FLAGS) - -if __name__ == '__main__': - app.run(main) + return optimizer \ No newline at end of file diff --git a/examples/ppo/unit_tests.py b/examples/ppo/ppo_lib_test.py similarity index 93% rename from examples/ppo/unit_tests.py rename to examples/ppo/ppo_lib_test.py index 6bf94d3341..11e8262dcc 100644 --- a/examples/ppo/unit_tests.py +++ b/examples/ppo/ppo_lib_test.py @@ -5,7 +5,7 @@ import numpy.testing as onp_testing from absl.testing import absltest -import main +import ppo_lib import env_utils import models @@ -20,7 +20,7 @@ def test_gae_shape_on_random(self): values = onp.random.random(size=(steps + 1, envs)) discount = 0.99 gae_param = 0.95 - adv = main.gae_advantages(rewards, terminal_masks, values, discount, + adv = ppo_lib.gae_advantages(rewards, terminal_masks, values, discount, gae_param) self.assertEqual(adv.shape, (steps, envs)) def test_gae_hardcoded(self): @@ -32,8 +32,8 @@ def test_gae_hardcoded(self): discount = 0.5 gae_param = 0.25 correct_gae = onp.array([[0.375, -0.5546875], [-1., -0.4375], [-1.5, 0.5]]) - actual_gae = main.gae_advantages(rewards, terminal_masks, values, discount, - gae_param) + actual_gae = ppo_lib.gae_advantages(rewards, terminal_masks, values, + discount, gae_param) onp_testing.assert_allclose(actual_gae, correct_gae) # test environment and preprocessing class TestEnvironmentPreprocessing(absltest.TestCase): @@ -104,12 +104,13 @@ def test_optimization_step(self): vf_coeff = 0.5 entropy_coeff = 0.01 lr = 2.5e-4 + batch_size = 256 key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = models.create_model(subkey, num_outputs) optimizer = models.create_optimizer(model, learning_rate=lr) - optimizer, _, _ = main.train_step( - optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr) + optimizer, _, _ = ppo_lib.train_step( + optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr, batch_size) self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) if __name__ == '__main__': diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py new file mode 100644 index 0000000000..4a96bbc5c2 --- /dev/null +++ b/examples/ppo/ppo_main.py @@ -0,0 +1,86 @@ +from absl import flags +from absl import app +import jax +import jax.random + +import ppo_lib +import models +import env_utils + +FLAGS = flags.FLAGS + +# 
default hyperparameters taken from PPO paper and openAI baselines 2 +# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py + +flags.DEFINE_float( + 'learning_rate', default=2.5e-4, + help=('The learning rate for the Adam optimizer.') +) + +flags.DEFINE_integer( + 'batch_size', default=256, + help=('Batch size for training.') +) + +flags.DEFINE_integer( + 'num_agents', default=8, + help=('Number of agents playing in parallel.') +) + +flags.DEFINE_integer( + 'actor_steps', default=128, + help=('Batch size for training.') +) + +flags.DEFINE_integer( + 'num_epochs', default=3, + help=('Number of epochs per each unroll of the policy.') +) + +flags.DEFINE_float( + 'gamma', default=0.99, + help=('Discount parameter.') +) + +flags.DEFINE_float( + 'lambda_', default=0.95, + help=('Generalized Advantage Estimation parameter.') +) + +flags.DEFINE_float( + 'clip_param', default=0.1, + help=('The PPO clipping parameter used to clamp ratios in loss function.') +) + +flags.DEFINE_float( + 'vf_coeff', default=0.5, + help=('Weighs value function loss in the total loss.') +) + +flags.DEFINE_float( + 'entropy_coeff', default=0.01, + help=('Weighs entropy bonus in the total loss.') +) + +flags.DEFINE_boolean( + 'decaying_lr_and_clip_param', default=True, + help=(('Linearly decay learning rate and clipping parameter to zero during ' + 'the training.')) +) + +def main(argv): + game = "Pong" + game += "NoFrameskip-v4" + num_actions = env_utils.get_num_actions(game) + print(f"Playing {game} with {num_actions} actions") + num_agents = FLAGS.num_agents + total_frames = 40000000 + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + model = models.create_model(subkey, num_outputs=num_actions) + optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) + del model + optimizer = ppo_lib.train(optimizer, game, total_frames, num_agents, FLAGS) + +if __name__ == '__main__': + app.run(main) diff --git a/setup.py b/setup.py index 43f387c877..6846acfa4a 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ ] tests_require = [ + "atari-py", "gym", "jaxlib", "ml-collections", From 9ff33b97091cc74620563934d6d526b0248fa11f Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 25 Sep 2020 08:10:25 +0000 Subject: [PATCH 39/48] Merge remote.py with agent.py due to similar function --- examples/ppo/agent.py | 55 +++++++++++++++++++++++++++++++++-- examples/ppo/ppo_lib.py | 7 ++--- examples/ppo/remote.py | 52 --------------------------------- examples/ppo/test_episodes.py | 5 ++-- 4 files changed, 57 insertions(+), 62 deletions(-) delete mode 100644 examples/ppo/remote.py diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 6a75b69e1d..3d75afb10b 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -1,12 +1,61 @@ -"""Agent's moves (forward pass of the network).""" +"""Agent utilities, incl. choosing the move and running in separate process.""" +import multiprocessing +import collections import jax import numpy as onp +import env_utils + @jax.jit def policy_action(model, state): """Forward pass of the network.""" - # Potentially the random choice of the action from probabilities can be moved - # here with additional rng_key parameter. out = model(state) return out + + + +ExpTuple = collections.namedtuple( + 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) + + +class RemoteSimulator: + """Wrap functionality for an agent emulating Atari in a separate process. + + An object of this class is created for every agent. 
+ """ + + def __init__(self, game): + parent_conn, child_conn = multiprocessing.Pipe() + self.proc = multiprocessing.Process( + target=rcv_action_send_exp, args=(child_conn, game)) + self.conn = parent_conn + self.proc.start() + + +def rcv_action_send_exp(conn, game): + """Run the remote agents. + + Receive action from the main learner, perform one step of simulation and + send back collected experience. + """ + env = env_utils.create_env(game, clip_rewards=True) + while True: + obs = env.reset() + done = False + state = get_state(obs) + while not done: + conn.send(state) + action, value, log_prob = conn.recv() + obs, reward, done, _ = env.step(action) + next_state = get_state(obs) if not done else None + experience = ExpTuple(state, action, reward, value, log_prob, done) + conn.send(experience) + if done: + break + state = next_state + +def get_state(observation): + """Convert Atari env observation into a NumPy array, add batch dimension.""" + state = onp.array(observation) + return state[None, ...] diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 3a275993fb..2ff50bb64b 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -10,7 +10,6 @@ import flax import agent -import remote import test_episodes @functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @@ -112,7 +111,7 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): def get_experience( model: flax.optim.base.Optimizer, - simulators: List[remote.RemoteSimulator], + simulators: List[agent.RemoteSimulator], steps_per_actor: int): """Collect experience from agents. @@ -142,7 +141,7 @@ def get_experience( return all_experience def process_experience( - experience: List[List[remote.ExpTuple]], + experience: List[List[agent.ExpTuple]], actor_steps: int, num_agents: int, gamma: float, @@ -207,7 +206,7 @@ def train( Returns: optimizer: the trained optimizer """ - simulators = [remote.RemoteSimulator(game) for i in range(num_agents)] + simulators = [agent.RemoteSimulator(game) for i in range(num_agents)] loop_steps = steps_total // (num_agents * flags_.actor_steps) for s in range(loop_steps): # Bookkeeping and testing. diff --git a/examples/ppo/remote.py b/examples/ppo/remote.py deleted file mode 100644 index 8625f05f4a..0000000000 --- a/examples/ppo/remote.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Utilities for running the agents in separate processes.""" - -import multiprocessing -import numpy as onp -import collections - -import env_utils - -ExpTuple = collections.namedtuple( - 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) - -class RemoteSimulator: - """Wrap functionality for an agent emulating Atari in a separate process. - - An object of this class is created for every agent. - """ - - def __init__(self, game): - parent_conn, child_conn = multiprocessing.Pipe() - self.proc = multiprocessing.Process( - target=rcv_action_send_exp, args=(child_conn, game)) - self.conn = parent_conn - self.proc.start() - - -def rcv_action_send_exp(conn, game): - """Run the remote agents. - - Receive action from the main learner, perform one step of simulation and - send back collected experience. 
- """ - env = env_utils.create_env(game, clip_rewards=True) - while True: - obs = env.reset() - done = False - state = get_state(obs) - while not done: - conn.send(state) - action, value, log_prob = conn.recv() - obs, reward, done, _ = env.step(action) - next_state = get_state(obs) if not done else None - experience = ExpTuple(state, action, reward, value, log_prob, done) - conn.send(experience) - if done: - break - state = next_state - - -def get_state(observation): - """Convert Atari env observation into a NumPy array, add batch dimension.""" - state = onp.array(observation) - return state[None, ...] diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 4c8aa76e27..11fe3cd4c4 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -5,7 +5,6 @@ import numpy as onp import env_utils -import remote import agent def policy_test( @@ -26,7 +25,7 @@ def policy_test( all_probabilities = [] for _ in range(n_episodes): obs = test_env.reset() - state = remote.get_state(obs) + state = agent.get_state(obs) total_reward = 0.0 for t in itertools.count(): log_probs, _ = agent.policy_action(model, state) @@ -37,7 +36,7 @@ def policy_test( obs, reward, done, _ = test_env.step(action) total_reward += reward if not done: - next_state = remote.get_state(obs) + next_state = agent.get_state(obs) else: next_state = None state = next_state From 08bd3449598e1f33ae11c5c67f5974dacbdd543a Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Fri, 25 Sep 2020 13:41:42 +0000 Subject: [PATCH 40/48] Use tensorboard for logging and add checkpointing --- examples/ppo/ppo_lib.py | 33 ++++++++++++++++++--------------- examples/ppo/ppo_main.py | 12 +++++++++++- examples/ppo/test_episodes.py | 15 +++------------ 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 2ff50bb64b..4bc0b3450a 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -8,6 +8,8 @@ import jax.numpy as jnp import numpy as onp import flax +from flax.metrics import tensorboard +from flax.training import checkpoints import agent import test_episodes @@ -193,7 +195,6 @@ def train( optimizer: flax.optim.base.Optimizer, game: str, steps_total: int, - num_agents: int, flags_: flags._flagvalues.FlagValues): """Main training loop. @@ -201,26 +202,28 @@ def train( optimizer: optimizer for the actor-critic model game: string specifying the Atari game from gym package steps total: total number of frames (env steps) to train on - num_agents: number of separate processes with agents running the envs Returns: optimizer: the trained optimizer """ - simulators = [agent.RemoteSimulator(game) for i in range(num_agents)] - loop_steps = steps_total // (num_agents * flags_.actor_steps) + simulators = [agent.RemoteSimulator(game) for i in range(flags_.num_agents)] + model_dir = '/tmp/ppo_training/' + summary_writer = tensorboard.SummaryWriter(model_dir) + loop_steps = steps_total // (flags_.num_agents * flags_.actor_steps) + log_frequency = 40 + checkpoint_frequency = 200 for s in range(loop_steps): # Bookkeeping and testing. - print(f"\n training loop step {s}") - - if (s + 1) % (20000 // (num_agents * flags_.actor_steps)) == 0: - test_episodes.policy_test(1, optimizer.target, game) - - if flags_.decaying_lr_and_clip_param: - alpha = 1. - s/loop_steps - else: - alpha = 1. 
+ if s % log_frequency == 0: + score = test_episodes.policy_test(1, optimizer.target, game) + frames = s * flags_.num_agents * flags_.actor_steps + summary_writer.scalar('game_score', score, frames) + print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n') + if s % checkpoint_frequency == 0: + checkpoints.save_checkpoint(model_dir, optimizer, s) # Core training code. + alpha = 1. - s/loop_steps if flags_.decaying_lr_and_clip_param else 1. all_experiences = get_experience(optimizer.target, simulators, flags_.actor_steps) trajectories = process_experience( @@ -229,10 +232,10 @@ def train( lr = flags_.learning_rate * alpha clip_param = flags_.clip_param * alpha for e in range(flags_.num_epochs): - permutation = onp.random.permutation(num_agents * flags_.actor_steps) + permutation = onp.random.permutation( + flags_.num_agents * flags_.actor_steps) trajectories = tuple(map(lambda x: x[permutation], trajectories)) optimizer, loss, last_iter_grad_norm = train_step( optimizer, trajectories, clip_param, flags_.vf_coeff, flags_.entropy_coeff, lr, flags_.batch_size) - print(f"epoch {e} loss {loss} grad norm {last_iter_grad_norm}") return optimizer \ No newline at end of file diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py index 4a96bbc5c2..f5130c50df 100644 --- a/examples/ppo/ppo_main.py +++ b/examples/ppo/ppo_main.py @@ -68,6 +68,16 @@ 'the training.')) ) +flags.DEFINE_string( + 'game', default='Pong', + help=('The Atari game used.') +) + +flags.DEFINE_string( + 'logdir', default='/tmp/ppo_training', + help=('Directory to set .') +) + def main(argv): game = "Pong" game += "NoFrameskip-v4" @@ -80,7 +90,7 @@ def main(argv): model = models.create_model(subkey, num_outputs=num_actions) optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model - optimizer = ppo_lib.train(optimizer, game, total_frames, num_agents, FLAGS) + optimizer = ppo_lib.train(optimizer, game, total_frames, FLAGS) if __name__ == '__main__': app.run(main) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 11fe3cd4c4..c550d111f5 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -7,10 +7,7 @@ import env_utils import agent -def policy_test( - n_episodes: int, - model: flax.nn.base.Model, - game: str): +def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str): """Perform a test of the policy in Atari environment. 
Args: @@ -19,10 +16,9 @@ def policy_test( game: defines the Atari game to test on Returns: - None + total_reward: obtained score """ test_env = env_utils.create_env(game, clip_rewards=False) - all_probabilities = [] for _ in range(n_episodes): obs = test_env.reset() state = agent.get_state(obs) @@ -31,7 +27,6 @@ def policy_test( log_probs, _ = agent.policy_action(model, state) probs = onp.exp(onp.array(log_probs, dtype=onp.float32)) probabilities = probs[0] / probs[0].sum() - all_probabilities.append(probabilities) action = onp.random.choice(probs.shape[1], p=probabilities) obs, reward, done, _ = test_env.step(action) total_reward += reward @@ -41,9 +36,5 @@ def policy_test( next_state = None state = next_state if done: - all_probabilities = onp.stack(all_probabilities, axis=0) - vars = onp.var(all_probabilities, axis=0) - print(f"------> TEST FINISHED: reward {total_reward} in {t} steps") - print(f"Variance of probabilities across encuntered states {vars}") break - del test_env + return total_reward From 65faed8dd15f7bc816a5fb8e26d503e199098b58 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 28 Sep 2020 13:29:59 +0000 Subject: [PATCH 41/48] Simplify and format code --- examples/ppo/agent.py | 5 +-- examples/ppo/env_utils.py | 7 +-- examples/ppo/models.py | 6 +-- examples/ppo/ppo_lib.py | 85 ++++++++++++++++++----------------- examples/ppo/ppo_lib_test.py | 9 ++-- examples/ppo/ppo_main.py | 70 +++++++++++++++-------------- examples/ppo/test_episodes.py | 5 +-- 7 files changed, 95 insertions(+), 92 deletions(-) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 3d75afb10b..bed3bb0587 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -14,9 +14,8 @@ def policy_action(model, state): return out - ExpTuple = collections.namedtuple( - 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) + 'ExpTuple', ['state', 'action', 'reward', 'value', 'log_prob', 'done']) class RemoteSimulator: @@ -28,7 +27,7 @@ class RemoteSimulator: def __init__(self, game): parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( - target=rcv_action_send_exp, args=(child_conn, game)) + target=rcv_action_send_exp, args=(child_conn, game)) self.conn = parent_conn self.proc.start() diff --git a/examples/ppo/env_utils.py b/examples/ppo/env_utils.py index 0150f0a5fe..fb7a6909d2 100644 --- a/examples/ppo/env_utils.py +++ b/examples/ppo/env_utils.py @@ -25,9 +25,10 @@ class FrameStack: Wraps an AtariPreprocessing object. """ - def __init__(self, - preproc: seed_rl_atari_preprocessing.AtariPreprocessing, - num_frames: int): + def __init__( + self, + preproc: seed_rl_atari_preprocessing.AtariPreprocessing, + num_frames: int): self.preproc = preproc self.num_frames = num_frames self.frames = collections.deque(maxlen=num_frames) diff --git a/examples/ppo/models.py b/examples/ppo/models.py index 8708d9dd87..19b38a4311 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -32,15 +32,15 @@ def apply(self, x, num_outputs): x = x.reshape((x.shape[0], -1)) # flatten x = nn.Dense(x, features=512, name='hidden', dtype=dtype) x = nn.relu(x) - # network used to both estimate policy (logits) and expected state value - # see github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py + # Network used to both estimate policy (logits) and expected state value. 
+ # See github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py logits = nn.Dense(x, features=num_outputs, name='logits', dtype=dtype) policy_log_probabilities = nn.log_softmax(logits) value = nn.Dense(x, features=1, name='value', dtype=dtype) return policy_log_probabilities, value def create_model(key, num_outputs): - input_dims = (1, 84, 84, 4) #(minibatch, height, width, stacked frames) + input_dims = (1, 84, 84, 4) # (minibatch, height, width, stacked frames) module = ActorCritic.partial(num_outputs=num_outputs) _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)]) model = flax.nn.Model(module, initial_par) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 4bc0b3450a..7369c3e81e 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -8,6 +8,7 @@ import jax.numpy as jnp import numpy as onp import flax +from flax import nn from flax.metrics import tensorboard from flax.training import checkpoints @@ -33,9 +34,11 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): Returns: advantages: calculated advantages shaped (actor_steps, num_agents) """ - assert rewards.shape[0] + 1 == values.shape[0], ("One more value needed; " - "Eq. (12) in PPO paper requires V(s_{t+1}) to calculate delta_t") - advantages, gae = [], 0. + assert rewards.shape[0] + 1 == values.shape[0], ('One more value needed; Eq. ' + '(12) in PPO paper requires ' + 'V(s_{t+1}) for delta_t') + advantages = [] + gae = 0. for t in reversed(range(len(rewards))): # Masks used to set next state value to 0 for terminal states. value_diff = discount * values[t + 1] * terminal_masks[t] - values[t] @@ -49,13 +52,13 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): @functools.partial(jax.jit, static_argnums=(6)) def train_step( - optimizer: flax.optim.base.Optimizer, - trajectories: Tuple[onp.array, onp.array, onp.array, onp.array, onp.array], - clip_param: float, - vf_coeff: float, - entropy_coeff: float, - lr: float, - batch_size: int): + optimizer: flax.optim.base.Optimizer, + trajectories: Tuple[onp.array, onp.array, onp.array, onp.array, onp.array], + clip_param: float, + vf_coeff: float, + entropy_coeff: float, + lr: float, + batch_size: int): """Compilable train step. Runs an entire epoch of training (i.e. the loop over @@ -79,12 +82,11 @@ def train_step( Returns: optimizer: new optimizer after the parameters update loss: loss summed over training steps - grad_norm: gradient norm from last step (summed over parameters) """ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): states, actions, old_log_probs, returns, advantages = minibatch log_probs, values = model(states) - values = values[:, 0] # convert shapes: (batch, 1) to (batch, ) + values = values[:, 0] # Convert shapes: (batch, 1) to (batch, ). probs = jnp.exp(log_probs) entropy = jnp.sum(-probs*log_probs, axis=1).mean() log_probs_act_taken = jax.vmap(lambda lp, a: lp[a])(log_probs, actions) @@ -93,14 +95,14 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) PG_loss = ratios * advantages clipped_loss = advantages * jax.lax.clamp(1. - clip_param, ratios, - 1. + clip_param) + 1. 
+ clip_param) PPO_loss = -jnp.mean(jnp.minimum(PG_loss, clipped_loss), axis=0) value_loss = jnp.mean(jnp.square(returns - values), axis=0) return PPO_loss + vf_coeff*value_loss - entropy_coeff*entropy iterations = trajectories[0].shape[0] // batch_size trajectories = jax.tree_map( - lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) + lambda x: x.reshape((iterations, batch_size) + x.shape[1:]), trajectories) loss = 0. for batch in zip(*trajectories): grad_fn = jax.value_and_grad(loss_fn) @@ -108,13 +110,12 @@ def loss_fn(model, minibatch, clip_param, vf_coeff, entropy_coeff): entropy_coeff) loss += l optimizer = optimizer.apply_gradient(grad, learning_rate=lr) - grad_norm = sum(jnp.square(g).sum() for g in jax.tree_leaves(grad)) - return optimizer, loss, grad_norm + return optimizer, loss def get_experience( - model: flax.optim.base.Optimizer, - simulators: List[agent.RemoteSimulator], - steps_per_actor: int): + model: nn.base.Model, + simulators: List[agent.RemoteSimulator], + steps_per_actor: int): """Collect experience from agents. Runs `steps_per_actor` time steps of the game for each of the `simulators`. @@ -143,11 +144,11 @@ def get_experience( return all_experience def process_experience( - experience: List[List[agent.ExpTuple]], - actor_steps: int, - num_agents: int, - gamma: float, - lambda_: float): + experience: List[List[agent.ExpTuple]], + actor_steps: int, + num_agents: int, + gamma: float, + lambda_: float): """Process experience for training, including advantage estimation. Args: @@ -186,32 +187,33 @@ def process_experience( # After preprocessing, concatenate data from all agents. trajectories = (states, actions, log_probs, returns, advantages) trajectories = tuple(map( - lambda x: onp.reshape( - x, (num_agents * actor_steps,) + x.shape[2:]), - trajectories)) + lambda x: onp.reshape( + x, (num_agents * actor_steps,) + x.shape[2:]), + trajectories)) return trajectories def train( - optimizer: flax.optim.base.Optimizer, - game: str, - steps_total: int, - flags_: flags._flagvalues.FlagValues): + optimizer: flax.optim.base.Optimizer, + flags_: flags._flagvalues.FlagValues): """Main training loop. Args: optimizer: optimizer for the actor-critic model - game: string specifying the Atari game from gym package - steps total: total number of frames (env steps) to train on + flags_: object holding hyperparameters and the training information Returns: optimizer: the trained optimizer """ - simulators = [agent.RemoteSimulator(game) for i in range(flags_.num_agents)] + game = flags_.game + 'NoFrameskip-v4' + simulators = [agent.RemoteSimulator(game) + for _ in range(flags_.num_agents)] model_dir = '/tmp/ppo_training/' summary_writer = tensorboard.SummaryWriter(model_dir) - loop_steps = steps_total // (flags_.num_agents * flags_.actor_steps) + loop_steps = flags_.total_frames // (flags_.num_agents * flags_.actor_steps) log_frequency = 40 checkpoint_frequency = 200 + + for s in range(loop_steps): # Bookkeeping and testing. 
if s % log_frequency == 0: @@ -227,15 +229,16 @@ def train( all_experiences = get_experience(optimizer.target, simulators, flags_.actor_steps) trajectories = process_experience( - all_experiences, flags_.actor_steps, flags_.num_agents, flags_.gamma, - flags_.lambda_) + all_experiences, flags_.actor_steps, flags_.num_agents, flags_.gamma, + flags_.lambda_) lr = flags_.learning_rate * alpha clip_param = flags_.clip_param * alpha for e in range(flags_.num_epochs): permutation = onp.random.permutation( flags_.num_agents * flags_.actor_steps) trajectories = tuple(map(lambda x: x[permutation], trajectories)) - optimizer, loss, last_iter_grad_norm = train_step( - optimizer, trajectories, clip_param, flags_.vf_coeff, - flags_.entropy_coeff, lr, flags_.batch_size) - return optimizer \ No newline at end of file + optimizer, loss = train_step( + optimizer, trajectories, clip_param, flags_.vf_coeff, + flags_.entropy_coeff, lr, flags_.batch_size) + return optimizer + \ No newline at end of file diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py index 11e8262dcc..f12ffc2048 100644 --- a/examples/ppo/ppo_lib_test.py +++ b/examples/ppo/ppo_lib_test.py @@ -21,7 +21,7 @@ def test_gae_shape_on_random(self): discount = 0.99 gae_param = 0.95 adv = ppo_lib.gae_advantages(rewards, terminal_masks, values, discount, - gae_param) + gae_param) self.assertEqual(adv.shape, (steps, envs)) def test_gae_hardcoded(self): #test on small example that can be verified by hand @@ -39,7 +39,7 @@ def test_gae_hardcoded(self): class TestEnvironmentPreprocessing(absltest.TestCase): def choose_random_game(self): games = ['BeamRider', 'Breakout', 'Pong', - 'Qbert', 'Seaquest', 'SpaceInvaders'] + 'Qbert', 'Seaquest', 'SpaceInvaders'] ind = onp.random.choice(len(games)) return games[ind] + "NoFrameskip-v4" @@ -109,8 +109,9 @@ def test_optimization_step(self): key, subkey = jax.random.split(key) model = models.create_model(subkey, num_outputs) optimizer = models.create_optimizer(model, learning_rate=lr) - optimizer, _, _ = ppo_lib.train_step( - optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr, batch_size) + optimizer, _ = ppo_lib.train_step( + optimizer, trn_data, clip_param, vf_coeff, entropy_coeff, lr, + batch_size) self.assertTrue(isinstance(optimizer, flax.optim.base.Optimizer)) if __name__ == '__main__': diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py index f5130c50df..1a959ab1e5 100644 --- a/examples/ppo/ppo_main.py +++ b/examples/ppo/ppo_main.py @@ -9,88 +9,90 @@ FLAGS = flags.FLAGS -# default hyperparameters taken from PPO paper and openAI baselines 2 +# Default hyperparameters originate from PPO paper and openAI baselines 2. 
# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py flags.DEFINE_float( - 'learning_rate', default=2.5e-4, - help=('The learning rate for the Adam optimizer.') + 'learning_rate', default=2.5e-4, + help=('The learning rate for the Adam optimizer.') ) flags.DEFINE_integer( - 'batch_size', default=256, - help=('Batch size for training.') + 'batch_size', default=256, + help=('Batch size for training.') ) flags.DEFINE_integer( - 'num_agents', default=8, - help=('Number of agents playing in parallel.') + 'num_agents', default=8, + help=('Number of agents playing in parallel.') ) flags.DEFINE_integer( - 'actor_steps', default=128, - help=('Batch size for training.') + 'actor_steps', default=128, + help=('Batch size for training.') ) flags.DEFINE_integer( - 'num_epochs', default=3, - help=('Number of epochs per each unroll of the policy.') + 'num_epochs', default=3, + help=('Number of epochs per each unroll of the policy.') ) flags.DEFINE_float( - 'gamma', default=0.99, - help=('Discount parameter.') + 'gamma', default=0.99, + help=('Discount parameter.') ) flags.DEFINE_float( - 'lambda_', default=0.95, - help=('Generalized Advantage Estimation parameter.') + 'lambda_', default=0.95, + help=('Generalized Advantage Estimation parameter.') ) flags.DEFINE_float( - 'clip_param', default=0.1, - help=('The PPO clipping parameter used to clamp ratios in loss function.') + 'clip_param', default=0.1, + help=('The PPO clipping parameter used to clamp ratios in loss function.') ) flags.DEFINE_float( - 'vf_coeff', default=0.5, - help=('Weighs value function loss in the total loss.') + 'vf_coeff', default=0.5, + help=('Weighs value function loss in the total loss.') ) flags.DEFINE_float( - 'entropy_coeff', default=0.01, - help=('Weighs entropy bonus in the total loss.') + 'entropy_coeff', default=0.01, + help=('Weighs entropy bonus in the total loss.') ) flags.DEFINE_boolean( - 'decaying_lr_and_clip_param', default=True, - help=(('Linearly decay learning rate and clipping parameter to zero during ' - 'the training.')) + 'decaying_lr_and_clip_param', default=True, + help=(('Linearly decay learning rate and clipping parameter to zero during ' + 'the training.')) ) flags.DEFINE_string( - 'game', default='Pong', - help=('The Atari game used.') + 'game', default='Pong', + help=('The Atari game used.') ) flags.DEFINE_string( - 'logdir', default='/tmp/ppo_training', - help=('Directory to set .') + 'logdir', default='/tmp/ppo_training', + help=('Directory to save checkpoints and logging info.') +) + +flags.DEFINE_integer( + 'total_frames', default=40000000, + help=('Length of training (total number of frames to be seen).') ) def main(argv): - game = "Pong" - game += "NoFrameskip-v4" + game = FLAGS.game + 'NoFrameskip-v4' num_actions = env_utils.get_num_actions(game) - print(f"Playing {game} with {num_actions} actions") - num_agents = FLAGS.num_agents - total_frames = 40000000 + print(f'Playing {game} with {num_actions} actions') key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = models.create_model(subkey, num_outputs=num_actions) optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) del model - optimizer = ppo_lib.train(optimizer, game, total_frames, FLAGS) + optimizer = ppo_lib.train(optimizer, FLAGS) if __name__ == '__main__': app.run(main) diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index c550d111f5..334a71513c 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -30,10 +30,7 @@ def 
policy_test(n_episodes: int, model: flax.nn.base.Model, game: str): action = onp.random.choice(probs.shape[1], p=probabilities) obs, reward, done, _ = test_env.step(action) total_reward += reward - if not done: - next_state = agent.get_state(obs) - else: - next_state = None + next_state = agent.get_state(obs) if not done else None state = next_state if done: break From 68b871335ad6d0c49ebd8b5a28870da95582fd43 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Mon, 28 Sep 2020 14:29:22 +0000 Subject: [PATCH 42/48] Save checkpoints less frequently --- examples/ppo/ppo_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 7369c3e81e..2abbad7cbe 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -211,7 +211,7 @@ def train( summary_writer = tensorboard.SummaryWriter(model_dir) loop_steps = flags_.total_frames // (flags_.num_agents * flags_.actor_steps) log_frequency = 40 - checkpoint_frequency = 200 + checkpoint_frequency = 500 for s in range(loop_steps): From 57dd0a37e81bc582a531254f570befb2cd2a28e3 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 29 Sep 2020 12:24:49 +0000 Subject: [PATCH 43/48] Update the README --- examples/ppo/README.md | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/examples/ppo/README.md b/examples/ppo/README.md index 5def2ceecf..88e570d670 100644 --- a/examples/ppo/README.md +++ b/examples/ppo/README.md @@ -1,12 +1,41 @@ # Proximal Policy Optimization -Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) to learn playing Atari games. +Uses the Proximal Policy Optimization algorithm ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) +to learn playing Atari games. ## Requirements -* This example depends on the `gym`, `opencv-python` and `atari-py` packages in addition to `jax` and `flax`. +This example depends on the `gym`, `opencv-python` and `atari-py` packages +in addition to `jax` and `flax`. + +## Supported setups + +The example should run with other configurations and hardware, but was explicitly +tested on the following: + +| Hardware | Game | Training time | Total frames seen | TensorBoard.dev | +| --- | --- | --- | --- | --- | +| 1x V100 GPU | TBA | TBA | TBA | TBA | ## How to run -`python main.py` runs the main training loop. -Unit tests can be run using `python unit_tests.py` \ No newline at end of file +Running `python ppo_main.py` will run the example with default +(hyper)parameters, i.e. for 40M frames on the Pong game. You can override the +default parameters, for example + +```python ppo_main.py --game=Seaquest --total_frames=20000000 --decaying_lr_and_clip_param=False --logdir=/tmp/seaquest``` + +will train the model on 20M Seaquest frames with constant (i.e. not linearly +decaying) learning rate and PPO clipping parameter. Checkpoints and tensorboard +files will be saved in `/tmp/seaquest`. + +Unit tests can be run using `python ppo_lib_test.py`. + +## How to run on Google Cloud TPU + +It is also possible to run this code on Google Cloud TPU. For detailed +instructions on the required setup, please refer to the [WMT example readme](https://github.com/google/flax/tree/master/examples/wmt). 
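For readers following the Cloud TPU route described above, one cheap sanity check before launching a long run is to confirm which accelerator JAX actually sees. The snippet below is an editorial sketch, not part of this patch; it assumes only a working `jax` installation.

```python
# Print the devices JAX will place computation on,
# e.g. TpuDevice / GpuDevice / CpuDevice entries.
import jax

print(jax.devices())
```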
+ +## Owners + +Jonathan Heek @jheek, Wojciech Rzadkowski @wrzadkow \ No newline at end of file From d7a8fa45f98200bdc7ac113c5020d3d28fef49f7 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Tue, 29 Sep 2020 13:29:59 +0000 Subject: [PATCH 44/48] Don't send values and log probs to remote process and back --- examples/ppo/agent.py | 4 ++-- examples/ppo/ppo_lib.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index bed3bb0587..403a136b56 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -45,10 +45,10 @@ def rcv_action_send_exp(conn, game): state = get_state(obs) while not done: conn.send(state) - action, value, log_prob = conn.recv() + action = conn.recv() obs, reward, done, _ = env.step(action) next_state = get_state(obs) if not done else None - experience = ExpTuple(state, action, reward, value, log_prob, done) + experience = (state, action, reward, done) conn.send(experience) if done: break diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 2abbad7cbe..1a085e0fd2 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -134,11 +134,13 @@ def get_experience( for i, sim in enumerate(simulators): probabilities = probs[i] action = onp.random.choice(probs.shape[1], p=probabilities) - # In principle, one could avoid sending value and log prob back and forth. - sim.conn.send((action, values[i, 0], log_probs[i][action])) + sim.conn.send(action) experiences = [] - for sim in simulators: - sample = sim.conn.recv() + for i, sim in enumerate(simulators): + state, action, reward, done = sim.conn.recv() + value = values[i, 0] + log_prob = log_probs[i][action] + sample = agent.ExpTuple(state, action, reward, value, log_prob, done) experiences.append(sample) all_experience.append(experiences) return all_experience From f9e37fea906977416ecae3c0cb876c11738c2c3a Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 30 Sep 2020 08:00:58 +0000 Subject: [PATCH 45/48] Add tensorboard.dev trace --- examples/ppo/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ppo/README.md b/examples/ppo/README.md index 88e570d670..139a509b3d 100644 --- a/examples/ppo/README.md +++ b/examples/ppo/README.md @@ -5,7 +5,7 @@ to learn playing Atari games. ## Requirements -This example depends on the `gym`, `opencv-python` and `atari-py` packages +This example depends on the `gym`, `opencv-python` and `atari-py` packages in addition to `jax` and `flax`. ## Supported setups @@ -15,7 +15,7 @@ tested on the following: | Hardware | Game | Training time | Total frames seen | TensorBoard.dev | | --- | --- | --- | --- | --- | -| 1x V100 GPU | TBA | TBA | TBA | TBA | +| 1x V100 GPU | Qbert | 9h 27m 8s | 40M | [2020-09-30](https://tensorboard.dev/experiment/1pacpbxxRz2di3NIOFkHoA/#scalars) | ## How to run @@ -29,7 +29,7 @@ will train the model on 20M Seaquest frames with constant (i.e. not linearly decaying) learning rate and PPO clipping parameter. Checkpoints and tensorboard files will be saved in `/tmp/seaquest`. -Unit tests can be run using `python ppo_lib_test.py`. +Unit tests can be run using `python ppo_lib_test.py`. 
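Once training has written a checkpoint, the saved parameters can be scored without restarting training. The following is an editorial sketch rather than part of the patch: it assumes the module layout of this example, the default `/tmp/ppo_training` checkpoint directory, and that `ppo_lib.train` has already saved at least one checkpoint there.

```python
# Sketch: restore the latest checkpoint and run a single test episode.
import jax
from flax.training import checkpoints

import env_utils
import models
import test_episodes

game = 'PongNoFrameskip-v4'
num_actions = env_utils.get_num_actions(game)
key = jax.random.PRNGKey(0)
model = models.create_model(key, num_outputs=num_actions)
optimizer = models.create_optimizer(model, learning_rate=2.5e-4)
# restore_checkpoint returns the target unchanged if no checkpoint is found.
optimizer = checkpoints.restore_checkpoint('/tmp/ppo_training', optimizer)
score = test_episodes.policy_test(1, optimizer.target, game)
print(f'Score over one test episode: {score}')
```

The `game_score` scalar logged during training can likewise be inspected by pointing TensorBoard at the same directory.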
## How to run on Google Cloud TPU From 70d21f7165d7d8bff413c4ba394cc589ac446e76 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 30 Sep 2020 09:02:54 +0000 Subject: [PATCH 46/48] Remove unneeded function get_state() --- examples/ppo/agent.py | 10 +++------- examples/ppo/test_episodes.py | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 403a136b56..2cdfa8e573 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -42,19 +42,15 @@ def rcv_action_send_exp(conn, game): while True: obs = env.reset() done = False - state = get_state(obs) + # Observations fetched from Atari env need additional batch dimension. + state = obs[None, ...] while not done: conn.send(state) action = conn.recv() obs, reward, done, _ = env.step(action) - next_state = get_state(obs) if not done else None + next_state = obs[None, ...] if not done else None experience = (state, action, reward, done) conn.send(experience) if done: break state = next_state - -def get_state(observation): - """Convert Atari env observation into a NumPy array, add batch dimension.""" - state = onp.array(observation) - return state[None, ...] diff --git a/examples/ppo/test_episodes.py b/examples/ppo/test_episodes.py index 334a71513c..084b880810 100644 --- a/examples/ppo/test_episodes.py +++ b/examples/ppo/test_episodes.py @@ -21,7 +21,7 @@ def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str): test_env = env_utils.create_env(game, clip_rewards=False) for _ in range(n_episodes): obs = test_env.reset() - state = agent.get_state(obs) + state = obs[None, ...] # add batch dimension total_reward = 0.0 for t in itertools.count(): log_probs, _ = agent.policy_action(model, state) @@ -30,7 +30,7 @@ def policy_test(n_episodes: int, model: flax.nn.base.Model, game: str): action = onp.random.choice(probs.shape[1], p=probabilities) obs, reward, done, _ = test_env.step(action) total_reward += reward - next_state = agent.get_state(obs) if not done else None + next_state = obs[None, ...] if not done else None state = next_state if done: break From 342786bf99acf9783b0e3300af5f6f471c44ab93 Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 30 Sep 2020 10:53:03 +0000 Subject: [PATCH 47/48] Small type hints & docstrings enhancement --- examples/ppo/agent.py | 5 +++-- examples/ppo/models.py | 5 +++-- examples/ppo/ppo_lib.py | 10 +++++++--- examples/ppo/ppo_lib_test.py | 2 ++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/ppo/agent.py b/examples/ppo/agent.py index 2cdfa8e573..1024f156eb 100644 --- a/examples/ppo/agent.py +++ b/examples/ppo/agent.py @@ -24,7 +24,8 @@ class RemoteSimulator: An object of this class is created for every agent. """ - def __init__(self, game): + def __init__(self, game: str): + """Start the remote process and create Pipe() to communicate with it.""" parent_conn, child_conn = multiprocessing.Pipe() self.proc = multiprocessing.Process( target=rcv_action_send_exp, args=(child_conn, game)) @@ -32,7 +33,7 @@ def __init__(self, game): self.proc.start() -def rcv_action_send_exp(conn, game): +def rcv_action_send_exp(conn, game: str): """Run the remote agents. 
Receive action from the main learner, perform one step of simulation and diff --git a/examples/ppo/models.py b/examples/ppo/models.py index 19b38a4311..0e23302226 100644 --- a/examples/ppo/models.py +++ b/examples/ppo/models.py @@ -1,5 +1,6 @@ """Class and functions to define and initialize the actor-critic model.""" +import numpy as onp import flax from flax import nn import jax.numpy as jnp @@ -39,14 +40,14 @@ def apply(self, x, num_outputs): value = nn.Dense(x, features=1, name='value', dtype=dtype) return policy_log_probabilities, value -def create_model(key, num_outputs): +def create_model(key: onp.ndarray, num_outputs: int): input_dims = (1, 84, 84, 4) # (minibatch, height, width, stacked frames) module = ActorCritic.partial(num_outputs=num_outputs) _, initial_par = module.init_by_shape(key, [(input_dims, jnp.float32)]) model = flax.nn.Model(module, initial_par) return model -def create_optimizer(model, learning_rate): +def create_optimizer(model: nn.base.Model, learning_rate: float): optimizer_def = flax.optim.Adam(learning_rate) optimizer = optimizer_def.create(model) return optimizer diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 1a085e0fd2..6e71d29c6d 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -17,7 +17,12 @@ @functools.partial(jax.vmap, in_axes=(1, 1, 1, None, None), out_axes=1) @jax.jit -def gae_advantages(rewards, terminal_masks, values, discount, gae_param): +def gae_advantages( + rewards: onp.ndarray, + terminal_masks: onp.ndarray, + values: onp.ndarray, + discount: float, + gae_param: float): """Use Generalized Advantage Estimation (GAE) to compute advantages. As defined by eqs. (11-12) in PPO paper arXiv: 1707.06347. Implementation uses @@ -53,7 +58,7 @@ def gae_advantages(rewards, terminal_masks, values, discount, gae_param): @functools.partial(jax.jit, static_argnums=(6)) def train_step( optimizer: flax.optim.base.Optimizer, - trajectories: Tuple[onp.array, onp.array, onp.array, onp.array, onp.array], + trajectories: Tuple, clip_param: float, vf_coeff: float, entropy_coeff: float, @@ -243,4 +248,3 @@ def train( optimizer, trajectories, clip_param, flags_.vf_coeff, flags_.entropy_coeff, lr, flags_.batch_size) return optimizer - \ No newline at end of file diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py index f12ffc2048..65d649e73a 100644 --- a/examples/ppo/ppo_lib_test.py +++ b/examples/ppo/ppo_lib_test.py @@ -1,3 +1,5 @@ +"""Unit tests for the PPO example.""" + import jax import flax from flax import nn From a4dade8ca04ebd6da1914a0800d54668414aa2da Mon Sep 17 00:00:00 2001 From: Wojciech Rzadkowski Date: Wed, 30 Sep 2020 22:19:02 +0000 Subject: [PATCH 48/48] Use ml_collections for hyperparameter handling --- examples/ppo/README.md | 12 +++-- examples/ppo/default_config.py | 40 ++++++++++++++++ examples/ppo/ppo_lib.py | 39 ++++++++-------- examples/ppo/ppo_main.py | 83 ++++------------------------------ 4 files changed, 79 insertions(+), 95 deletions(-) create mode 100644 examples/ppo/default_config.py diff --git a/examples/ppo/README.md b/examples/ppo/README.md index 139a509b3d..318955b985 100644 --- a/examples/ppo/README.md +++ b/examples/ppo/README.md @@ -20,10 +20,16 @@ tested on the following: ## How to run Running `python ppo_main.py` will run the example with default -(hyper)parameters, i.e. for 40M frames on the Pong game. You can override the -default parameters, for example +(hyper)parameters, i.e. for 40M frames on the Pong game. 
-```python ppo_main.py --game=Seaquest --total_frames=20000000 --decaying_lr_and_clip_param=False --logdir=/tmp/seaquest``` +By default logging info and checkpoints will be stored in `/tmp/ppo_training` +directory. This can be overriden as follows: + +```python ppo_main.py --logdir=/my_fav_directory``` + +You can also override the default (hyper)parameters, for example + +```python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --logdir=/tmp/seaquest``` will train the model on 20M Seaquest frames with constant (i.e. not linearly decaying) learning rate and PPO clipping parameter. Checkpoints and tensorboard diff --git a/examples/ppo/default_config.py b/examples/ppo/default_config.py new file mode 100644 index 0000000000..91d3745a18 --- /dev/null +++ b/examples/ppo/default_config.py @@ -0,0 +1,40 @@ +"""Definitions of default hyperparameters.""" + +import ml_collections + +def get_config(): + """Get the default configuration. + + The default hyperparameters originate from PPO paper arXiv:1707.06347 + and openAI baselines 2:: + https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py + """ + config = ml_collections.ConfigDict() + # The Atari game used. + config.game = 'Pong' + # Total number of frames seen during training. + config.total_frames = 40000000 + # The learning rate for the Adam optimizer. + config.learning_rate = 2.5e-4 + # Batch size used in training. + config.batch_size = 256 + # Number of agents playing in parallel. + config.num_agents = 8 + # Number of steps each agent performs in one policy unroll. + config.actor_steps = 128 + # Number of training epochs per each unroll of the policy. + config.num_epochs = 3 + # RL discount parameter. + config.gamma = 0.99 + # Generalized Advantage Estimation parameter. + config.lambda_ = 0.95 + # The PPO clipping parameter used to clamp ratios in loss function. + config.clip_param = 0.1 + # Weight of value function loss in the total loss. + config.vf_coeff = 0.5 + # Weight of entropy bonus in the total loss. + config.entropy_coeff = 0.01 + # Linearly decay learning rate and clipping parameter to zero during + # the training. + config.decaying_lr_and_clip_param = True + return config diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 6e71d29c6d..1e53303b68 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -2,7 +2,6 @@ import functools from typing import Tuple, List -from absl import flags import jax import jax.random import jax.numpy as jnp @@ -11,6 +10,7 @@ from flax import nn from flax.metrics import tensorboard from flax.training import checkpoints +import ml_collections import agent import test_episodes @@ -201,22 +201,23 @@ def process_experience( def train( optimizer: flax.optim.base.Optimizer, - flags_: flags._flagvalues.FlagValues): + config: ml_collections.ConfigDict, + model_dir: str): """Main training loop. 
Args: optimizer: optimizer for the actor-critic model - flags_: object holding hyperparameters and the training information + config: object holding hyperparameters and the training information + model_dir: path to dictionary where checkpoints and logging info are stored Returns: optimizer: the trained optimizer """ - game = flags_.game + 'NoFrameskip-v4' + game = config.game + 'NoFrameskip-v4' simulators = [agent.RemoteSimulator(game) - for _ in range(flags_.num_agents)] - model_dir = '/tmp/ppo_training/' + for _ in range(config.num_agents)] summary_writer = tensorboard.SummaryWriter(model_dir) - loop_steps = flags_.total_frames // (flags_.num_agents * flags_.actor_steps) + loop_steps = config.total_frames // (config.num_agents * config.actor_steps) log_frequency = 40 checkpoint_frequency = 500 @@ -225,26 +226,26 @@ def train( # Bookkeeping and testing. if s % log_frequency == 0: score = test_episodes.policy_test(1, optimizer.target, game) - frames = s * flags_.num_agents * flags_.actor_steps + frames = s * config.num_agents * config.actor_steps summary_writer.scalar('game_score', score, frames) print(f'Step {s}:\nframes seen {frames}\nscore {score}\n\n') if s % checkpoint_frequency == 0: checkpoints.save_checkpoint(model_dir, optimizer, s) # Core training code. - alpha = 1. - s/loop_steps if flags_.decaying_lr_and_clip_param else 1. - all_experiences = get_experience(optimizer.target, simulators, - flags_.actor_steps) + alpha = 1. - s/loop_steps if config.decaying_lr_and_clip_param else 1. + all_experiences = get_experience( + optimizer.target, simulators, config.actor_steps) trajectories = process_experience( - all_experiences, flags_.actor_steps, flags_.num_agents, flags_.gamma, - flags_.lambda_) - lr = flags_.learning_rate * alpha - clip_param = flags_.clip_param * alpha - for e in range(flags_.num_epochs): + all_experiences, config.actor_steps, config.num_agents, config.gamma, + config.lambda_) + lr = config.learning_rate * alpha + clip_param = config.clip_param * alpha + for e in range(config.num_epochs): permutation = onp.random.permutation( - flags_.num_agents * flags_.actor_steps) + config.num_agents * config.actor_steps) trajectories = tuple(map(lambda x: x[permutation], trajectories)) optimizer, loss = train_step( - optimizer, trajectories, clip_param, flags_.vf_coeff, - flags_.entropy_coeff, lr, flags_.batch_size) + optimizer, trajectories, clip_param, config.vf_coeff, + config.entropy_coeff, lr, config.batch_size) return optimizer diff --git a/examples/ppo/ppo_main.py b/examples/ppo/ppo_main.py index 1a959ab1e5..24b6bc2fef 100644 --- a/examples/ppo/ppo_main.py +++ b/examples/ppo/ppo_main.py @@ -1,7 +1,9 @@ +import os from absl import flags from absl import app import jax import jax.random +from ml_collections import config_flags import ppo_lib import models @@ -9,90 +11,25 @@ FLAGS = flags.FLAGS -# Default hyperparameters originate from PPO paper and openAI baselines 2. 
-# https://github.com/openai/baselines/blob/master/baselines/ppo2/defaults.py - -flags.DEFINE_float( - 'learning_rate', default=2.5e-4, - help=('The learning rate for the Adam optimizer.') -) - -flags.DEFINE_integer( - 'batch_size', default=256, - help=('Batch size for training.') -) - -flags.DEFINE_integer( - 'num_agents', default=8, - help=('Number of agents playing in parallel.') -) - -flags.DEFINE_integer( - 'actor_steps', default=128, - help=('Batch size for training.') -) - -flags.DEFINE_integer( - 'num_epochs', default=3, - help=('Number of epochs per each unroll of the policy.') -) - -flags.DEFINE_float( - 'gamma', default=0.99, - help=('Discount parameter.') -) - -flags.DEFINE_float( - 'lambda_', default=0.95, - help=('Generalized Advantage Estimation parameter.') -) - -flags.DEFINE_float( - 'clip_param', default=0.1, - help=('The PPO clipping parameter used to clamp ratios in loss function.') -) - -flags.DEFINE_float( - 'vf_coeff', default=0.5, - help=('Weighs value function loss in the total loss.') -) - -flags.DEFINE_float( - 'entropy_coeff', default=0.01, - help=('Weighs entropy bonus in the total loss.') -) - -flags.DEFINE_boolean( - 'decaying_lr_and_clip_param', default=True, - help=(('Linearly decay learning rate and clipping parameter to zero during ' - 'the training.')) -) - -flags.DEFINE_string( - 'game', default='Pong', - help=('The Atari game used.') -) - flags.DEFINE_string( 'logdir', default='/tmp/ppo_training', - help=('Directory to save checkpoints and logging info.') -) + help=('Directory to save checkpoints and logging info.')) -flags.DEFINE_integer( - 'total_frames', default=40000000, - help=('Length of training (total number of frames to be seen).') -) +config_flags.DEFINE_config_file( + 'config', os.path.join(os.path.dirname(__file__), 'default_config.py'), + 'File path to the default configuration file.') def main(argv): - game = FLAGS.game + 'NoFrameskip-v4' + config = FLAGS.config + game = config.game + 'NoFrameskip-v4' num_actions = env_utils.get_num_actions(game) print(f'Playing {game} with {num_actions} actions') key = jax.random.PRNGKey(0) key, subkey = jax.random.split(key) model = models.create_model(subkey, num_outputs=num_actions) - optimizer = models.create_optimizer(model, learning_rate=FLAGS.learning_rate) + optimizer = models.create_optimizer(model, learning_rate=config.learning_rate) del model - optimizer = ppo_lib.train(optimizer, FLAGS) + optimizer = ppo_lib.train(optimizer, config, FLAGS.logdir) if __name__ == '__main__': app.run(main)
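To illustrate the new `ml_collections`-based setup introduced in this last patch, the defaults from `default_config.py` can also be loaded and overridden programmatically. This is an editorial sketch (assuming it is run from the `examples/ppo` directory) that mirrors the `--config.*` command-line overrides documented in the README; it introduces no fields beyond those defined in `get_config()`.

```python
# Sketch: programmatic equivalent of
#   python ppo_main.py --config.game=Seaquest --config.total_frames=20000000 \
#       --config.decaying_lr_and_clip_param=False
import default_config

config = default_config.get_config()  # ml_collections.ConfigDict with the defaults
config.game = 'Seaquest'
config.total_frames = 20000000
config.decaying_lr_and_clip_param = False
print(config)
```

The resulting `config` object is exactly what the refactored `ppo_lib.train(optimizer, config, FLAGS.logdir)` expects.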