Commit b02b108

Merge pull request #123 from FLAIROx/jax-update

Jax version update

amacrutherford authored Dec 19, 2024
2 parents 98b572c + 908c905 commit b02b108
Showing 45 changed files with 216 additions and 212 deletions.
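
The sweep below replaces the deprecated top-level jax.tree_map and the long-form jax.tree_util.tree_map with jax.tree.map, the pytree namespace JAX added in 0.4.25, and bumps the NVIDIA JAX base image by a year. A minimal sketch of the new spelling, using hypothetical per-agent reward pytrees rather than the repo's real data:

import jax
import jax.numpy as jnp

# Old spellings swept out by this PR:
#   jax.tree_map(f, *trees)            # deprecated since JAX 0.4.25
#   jax.tree_util.tree_map(f, *trees)  # still valid, just verbose
# New spelling -- apply a function leaf-wise across pytrees with
# matching structure:
rewards = {"agent_0": jnp.array([1.0, 2.0]), "agent_1": jnp.array([3.0, 4.0])}
shaped = {"agent_0": jnp.array([0.5, 0.5]), "agent_1": jnp.array([0.1, 0.1])}
total = jax.tree.map(lambda r, s: r + s, rewards, shaped)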
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/jax:23.10-py3
+FROM nvcr.io/nvidia/jax:24.10-py3

# Create user
ARG UID
14 changes: 7 additions & 7 deletions baselines/IPPO/ippo_cnn_overcooked.py
@@ -237,9 +237,9 @@ def _env_step(runner_state, unused):

shaped_reward = info.pop("shaped_reward")
current_timestep = update_step*config["NUM_STEPS"]*config["NUM_ENVS"]
-reward = jax.tree_map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, shaped_reward)
+reward = jax.tree.map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, shaped_reward)

-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
action,
@@ -345,13 +345,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
@@ -375,7 +375,7 @@ def callback(metric):
wandb.log(metric)

update_step = update_step + 1
-metric = jax.tree_map(lambda x: x.mean(), metric)
+metric = jax.tree.map(lambda x: x.mean(), metric)
metric["update_step"] = update_step
metric["env_step"] = update_step * config["NUM_STEPS"] * config["NUM_ENVS"]
jax.debug.callback(callback, metric)
@@ -413,7 +413,7 @@ def single_run(config):

print("** Saving Results **")
filename = f'{config["ENV_NAME"]}_{layout_name}_seed{config["SEED"]}'
-train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0])
+train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0])
state_seq = get_rollout(train_state.params, config)
viz = OvercookedVisualizer()
# agent_view_size is hardcoded as it determines the padding around the layout.
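
For context, the shaped-reward annealing touched in the first hunk above follows this pattern; a runnable sketch, assuming rew_shaping_anneal is an optax-style linear schedule (values and horizon here are illustrative):

import jax
import jax.numpy as jnp
import optax

# Hypothetical anneal: shaped-reward weight decays linearly to zero.
rew_shaping_anneal = optax.linear_schedule(
    init_value=1.0, end_value=0.0, transition_steps=1_000_000
)

reward = {"agent_0": jnp.array(1.0), "agent_1": jnp.array(0.0)}
shaped_reward = {"agent_0": jnp.array(0.2), "agent_1": jnp.array(0.4)}
current_timestep = 250_000

# jax.tree.map over two pytrees: base reward plus annealed shaped reward,
# the same call shape as the line migrated above.
reward = jax.tree.map(
    lambda x, y: x + y * rew_shaping_anneal(current_timestep), reward, shaped_reward
)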
8 changes: 4 additions & 4 deletions baselines/IPPO/ippo_ff_hanabi.py
@@ -138,15 +138,15 @@ def _env_step(runner_state, unused):
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
-env_act = jax.tree_map(lambda x: x.squeeze(), env_act)
+env_act = jax.tree.map(lambda x: x.squeeze(), env_act)

# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
rng_step, env_state, env_act
)
-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
done_batch,
@@ -258,11 +258,11 @@ def _loss_fn(params, traj_batch, gae, targets):
batch = (traj_batch, advantages.squeeze(), targets.squeeze())
permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=1), batch
)

-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
12 changes: 6 additions & 6 deletions baselines/IPPO/ippo_ff_mabrax.py
@@ -143,7 +143,7 @@ def _env_step(runner_state, unused):
rng_step, env_state, env_act,
)

-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
action,
@@ -258,13 +258,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
@@ -292,8 +292,8 @@ def callback(metric):

update_count = update_count + 1
r0 = {"ratio0": loss_info["ratio"][0,0].mean()}
-loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
-metric = jax.tree_map(lambda x: x.mean(), metric)
+loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
+metric = jax.tree.map(lambda x: x.mean(), metric)
metric["update_step"] = update_count
metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"]
metric = {**metric, **loss_info, **r0}
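
The three migrated calls in the _loss_fn hunk above form one minibatching pipeline; a self-contained sketch with illustrative shapes, the arrays standing in for the real traj_batch, advantages, and targets:

import jax
import jax.numpy as jnp

num_steps, num_actors, num_minibatches = 4, 8, 2
batch_size = num_steps * num_actors

# Leaves share the leading (num_steps, num_actors) axes.
traj = jnp.arange(batch_size * 3, dtype=jnp.float32).reshape(num_steps, num_actors, 3)
advantages = jnp.ones((num_steps, num_actors))
targets = jnp.zeros((num_steps, num_actors))
batch = (traj, advantages, targets)

rng = jax.random.PRNGKey(0)
permutation = jax.random.permutation(rng, batch_size)

# Flatten steps and actors into one batch axis, shuffle it, then split
# into num_minibatches chunks; jax.tree.map keeps the tuple structure.
batch = jax.tree.map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch)
shuffled_batch = jax.tree.map(lambda x: jnp.take(x, permutation, axis=0), batch)
minibatches = jax.tree.map(
    lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])), shuffled_batch
)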
12 changes: 6 additions & 6 deletions baselines/IPPO/ippo_ff_mpe.py
@@ -140,7 +140,7 @@ def _env_step(runner_state, unused):
rng_step, env_state, env_act,
)

-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
action,
@@ -255,13 +255,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
@@ -288,8 +288,8 @@ def callback(metric):

r0 = {"ratio0": loss_info["ratio"][0,0].mean()}
# jax.debug.print('ratio0 {x}', x=r0["ratio0"])
-loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
-metric = jax.tree_map(lambda x: x.mean(), metric)
+loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
+metric = jax.tree.map(lambda x: x.mean(), metric)
metric = {**metric, **loss_info, **r0}
jax.experimental.io_callback(callback, None, metric)
runner_state = (train_state, env_state, last_obs, rng)
8 changes: 4 additions & 4 deletions baselines/IPPO/ippo_ff_mpe_facmac.py
@@ -146,7 +146,7 @@ def _env_step(runner_state, unused):
rng_step, env_state, env_act,
)

-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
action,
@@ -252,13 +252,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
14 changes: 7 additions & 7 deletions baselines/IPPO/ippo_ff_overcooked.py
@@ -212,7 +212,7 @@ def _env_step(runner_state, unused):
info["reward"] = reward["agent_0"]

current_timestep = update_step*config["NUM_STEPS"]*config["NUM_ENVS"]
-reward = jax.tree_map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, info["shaped_reward"])
+reward = jax.tree.map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, info["shaped_reward"])

transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
@@ -318,13 +318,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
@@ -353,7 +353,7 @@ def callback(metric):
metric
)
update_step = update_step + 1
-metric = jax.tree_map(lambda x: x.mean(), metric)
+metric = jax.tree.map(lambda x: x.mean(), metric)
metric["update_step"] = update_step
metric["env_step"] = update_step*config["NUM_STEPS"]*config["NUM_ENVS"]
jax.debug.callback(callback, metric)
@@ -393,7 +393,7 @@ def main(config):
out = jax.vmap(train_jit)(rngs)

filename = f'{config["ENV_NAME"]}_{layout_name}'
-train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0])
+train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0])
state_seq = get_rollout(train_state, config)
viz = OvercookedVisualizer()
# agent_view_size is hardcoded as it determines the padding around the layout.
@@ -415,7 +415,7 @@ def main(config):
plt.savefig(f'{filename}.png')
# animate first seed
-train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0])
+train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0])
state_seq = get_rollout(train_state, config)
viz = OvercookedVisualizer()
# agent_view_size is hardcoded as it determines the padding around the layout.
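
The train_state lines migrated above pull a single run out of vmapped output; a small sketch, with a hypothetical pytree standing in for the real runner state:

import jax
import jax.numpy as jnp

# out = jax.vmap(train_jit)(rngs) leaves a leading seed axis on every
# leaf; indexing [0] leaf-wise recovers the first seed's train state.
vmapped_state = {"params": {"w": jnp.ones((3, 4, 2))}}  # (num_seeds, ...)
first_seed = jax.tree.map(lambda x: x[0], vmapped_state)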
8 changes: 4 additions & 4 deletions baselines/IPPO/ippo_ff_switch_riddle.py
@@ -141,7 +141,7 @@ def _env_step(runner_state, unused):
rng_step, env_state, env_act,
)

-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
transition = Transition(
batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
action,
@@ -247,13 +247,13 @@ def _loss_fn(params, traj_batch, gae, targets):
), "batch size must be equal to number of steps * number of actors"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
-batch = jax.tree_util.tree_map(
+batch = jax.tree.map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
10 changes: 5 additions & 5 deletions baselines/IPPO/ippo_rnn_hanabi.py
@@ -183,14 +183,14 @@ def _env_step(runner_state, unused):
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
-env_act = jax.tree_map(lambda x: x.squeeze(), env_act)
+env_act = jax.tree.map(lambda x: x.squeeze(), env_act)
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
rng_step, env_state, env_act
)
-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
jnp.tile(done["__all__"], env.num_agents),
@@ -312,11 +312,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
batch = (init_hstate, traj_batch, advantages.squeeze(), targets.squeeze())
permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=1), batch
)

-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
@@ -342,7 +342,7 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
train_state = update_state[0]
metric = traj_batch.info
ratio_0 = loss_info[1][3].at[0,0].get().mean()
-loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],
10 changes: 5 additions & 5 deletions baselines/IPPO/ippo_rnn_mpe.py
@@ -198,7 +198,7 @@ def _env_step(runner_state, unused):
obsv, env_state, reward, done, info = jax.vmap(
env.step, in_axes=(0, 0, 0)
)(rng_step, env_state, env_act)
-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
jnp.tile(done["__all__"], env.num_agents),
@@ -334,11 +334,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=1), batch
)

-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
@@ -377,14 +377,14 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
train_state = update_state[0]
metric = traj_batch.info
-metric = jax.tree_map(
+metric = jax.tree.map(
lambda x: x.reshape(
(config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents)
),
traj_batch.info,
)
ratio_0 = loss_info[1][3].at[0,0].get().mean()
-loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],
10 changes: 5 additions & 5 deletions baselines/IPPO/ippo_rnn_smax.py
@@ -204,7 +204,7 @@ def _env_step(runner_state, unused):
obsv, env_state, reward, done, info = jax.vmap(
env.step, in_axes=(0, 0, 0)
)(rng_step, env_state, env_act)
-info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
+info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info)
done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze()
transition = Transition(
jnp.tile(done["__all__"], env.num_agents),
@@ -346,11 +346,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])

-shuffled_batch = jax.tree_util.tree_map(
+shuffled_batch = jax.tree.map(
lambda x: jnp.take(x, permutation, axis=1), batch
)

-minibatches = jax.tree_util.tree_map(
+minibatches = jax.tree.map(
lambda x: jnp.swapaxes(
jnp.reshape(
x,
@@ -389,14 +389,14 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
)
train_state = update_state[0]
metric = traj_batch.info
-metric = jax.tree_map(
+metric = jax.tree.map(
lambda x: x.reshape(
(config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents)
),
traj_batch.info,
)
ratio_0 = loss_info[1][3].at[0,0].get().mean()
-loss_info = jax.tree_map(lambda x: x.mean(), loss_info)
+loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
metric["loss"] = {
"total_loss": loss_info[0],
"value_loss": loss_info[1][0],