
Commit 961bd56

[Feat] changed structure of stepwise ppo
LTluttmann committed Sep 5, 2024
1 parent ffeffba commit 961bd56
Showing 4 changed files with 53 additions and 35 deletions.
1 change: 0 additions & 1 deletion configs/experiment/scheduling/base.yaml
@@ -9,7 +9,6 @@ defaults:
 logger:
   wandb:
     project: "rl4co"
-    log_model: "all"
     group: "${env.name}-${env.generator_params.num_jobs}-${env.generator_params.num_machines}"
     tags: ???
     name: ???
2 changes: 1 addition & 1 deletion configs/experiment/scheduling/matnet-ppo.yaml
@@ -37,6 +37,6 @@ model:
   val_batch_size: 512
   test_batch_size: 64
   mini_batch_size: 512
-  n_start: 8
+  n_start: 4
 env:
   stepwise_reward: True
83 changes: 51 additions & 32 deletions rl4co/models/rl/ppo/stepwise_ppo_ep_rew.py
@@ -45,14 +45,14 @@ def __init__(
         policy: nn.Module,
+        n_start: int = 0,
         clip_range: float = 0.2,  # epsilon of PPO
-        update_timestep: int = 1,
         buffer_size: int = 100_000,
         ppo_epochs: int = 2,  # inner epoch, K
         batch_size: int = 256,
         mini_batch_size: int = 256,
+        rollout_batch_size: int = 256,
         vf_lambda: float = 0.5,  # lambda of Value function fitting
-        entropy_lambda: float = 0.01,  # lambda of entropy bonus
-        max_grad_norm: float = 0.5,  # max gradient norm
+        entropy_lambda: float = 0.0,  # lambda of entropy bonus
+        max_grad_norm: float = 1.0,  # max gradient norm
         buffer_storage_device: str = "gpu",
         metrics: dict = {
             "train": ["loss", "surrogate_loss", "value_loss", "entropy"],
@@ -67,10 +67,10 @@ def __init__(
         self.rb = make_replay_buffer(buffer_size, mini_batch_size, buffer_storage_device)
         self.scaler = RewardScaler(reward_scale)
         self.n_start = n_start
+        self.rollout_batch_size = rollout_batch_size
         self.ppo_cfg = {
             "clip_range": clip_range,
             "ppo_epochs": ppo_epochs,
-            "update_timestep": update_timestep,
             "mini_batch_size": mini_batch_size,
             "vf_lambda": vf_lambda,
             "entropy_lambda": entropy_lambda,
@@ -133,7 +133,7 @@ def update(self, device):
             "reward": previous_reward.mean(),
             "loss": loss,
             "surrogate_loss": surrogate_loss,
-            "value_loss": value_loss,
+            # "value_loss": value_loss,
             "entropy": entropy.mean(),
         }
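Note on the logged terms above: the dictionary mirrors the usual PPO decomposition of the training loss. The snippet below is a minimal, self-contained sketch of how clip_range, vf_lambda and entropy_lambda typically combine these terms; it is illustrative only, not the body of update() in this file, and all tensor arguments are placeholders.

import torch
import torch.nn.functional as F

def ppo_loss(logprobs, old_logprobs, advantage, value_pred, value_target, entropy,
             clip_range=0.2, vf_lambda=0.5, entropy_lambda=0.0):
    # ratio between the current policy and the rollout (old) policy
    ratio = torch.exp(logprobs - old_logprobs.detach())
    # clipped surrogate objective of PPO
    surrogate_loss = torch.min(
        ratio * advantage,
        torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantage,
    ).mean()
    # value-function fitting term, weighted by vf_lambda
    value_loss = F.mse_loss(value_pred, value_target)
    # maximize surrogate and entropy, minimize value error
    loss = -surrogate_loss + vf_lambda * value_loss - entropy_lambda * entropy.mean()
    return loss, surrogate_loss, value_loss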

@@ -147,47 +147,66 @@ def shared_step(
         self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
     ):

-        next_td = self.env.reset(batch)
-        device = next_td.device

         if phase == "train":

-            if self.n_start > 1:
-                next_td = batchify(next_td, self.n_start)
+            for i in range(0, batch.shape[0], self.rollout_batch_size):

+                mini_batch = batch[i : i + self.rollout_batch_size]
+                rollout_td_buffer = []
+                next_td = self.env.reset(mini_batch)
+                device = next_td.device

+                if self.n_start > 1:
+                    next_td = batchify(next_td, self.n_start)

+                n_steps = 0
+                while not next_td["done"].all():

-            td_buffer = []
-            while not next_td["done"].all():
+                    with torch.no_grad():
+                        td = self.policy_old.act(
+                            next_td, self.env, phase="train", temp=2.0
+                        )

-                with torch.no_grad():
-                    td = self.policy_old.act(next_td, self.env, phase="train")
+                    rollout_td_buffer.append(td)
+                    # get next state
+                    next_td = self.env.step(td)["next"]
+                    n_steps += 1

-                td_buffer.append(td)
-                # get next state
-                next_td = self.env.step(td)["next"]
-            # get reward of action
-            reward = self.env.get_reward(next_td, 1)
+                # get rewards
+                reward = self.env.get_reward(next_td, 1) / n_steps

-            if self.n_start > 1:
-                reward_unbatched = unbatchify(reward, self.n_start)
-                advantage = reward - batchify(reward_unbatched.mean(-1), self.n_start)
-                advantage = self.scaler(advantage)
-                td_buffer = [td.set("advantage", advantage) for td in td_buffer]
-            else:
-                reward = self.scaler(reward)
-                td_buffer = [td.set("reward", reward) for td in td_buffer]
+                if self.n_start > 1:
+                    reward_unbatched = unbatchify(reward, self.n_start)
+                    advantage = (
+                        reward
+                        - batchify(reward_unbatched.mean(-1), self.n_start).detach()
+                    )
+                    advantage = self.scaler(advantage)
+                    rollout_td_buffer = [
+                        td.set("advantage", advantage) for td in rollout_td_buffer
+                    ]

+                else:
+                    reward = self.scaler(reward)
+                    rollout_td_buffer = [
+                        td.set("reward", reward) for td in rollout_td_buffer
+                    ]

-            # add tensordict with action, logprobs and reward information to buffer
-            self.rb.extend(torch.cat(td_buffer, dim=0))
+                # add tensordict with action, logprobs and reward information to buffer
+                self.rb.extend(torch.cat(rollout_td_buffer, dim=0))

-            # if iter mod x = 0 then update the policy (x = 1 in paper)
-            if batch_idx % self.ppo_cfg["update_timestep"] == 0:
-                out = self.update(device)
-                self.rb.empty()
+            out = self.update(device)

+            self.rb.empty()
+            torch.cuda.empty_cache()

         else:
+            next_td = self.env.reset(batch)
             out = self.policy.generate(
                 next_td, self.env, phase=phase, select_best=phase != "train"
             )

         metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)

         return {"loss": out.get("loss", None), **metrics}
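In the restructured training branch above, the batch is processed in chunks of rollout_batch_size, the episodic reward is divided by the number of decoding steps, and for n_start > 1 the advantage is each start's reward minus the mean reward over all starts of the same instance. The following is a small stand-alone sketch of that shared-baseline advantage, written for a reward tensor of shape [batch_size, n_start] (roughly what unbatchify produces); it illustrates the idea and is not the code of this module.

import torch

def multistart_advantage(reward: torch.Tensor) -> torch.Tensor:
    """Shared-baseline advantage for multistart rollouts.

    reward: [batch_size, n_start] episodic reward of each start per instance.
    Returns reward minus the per-instance mean over starts, so the advantages
    of one instance always sum to zero.
    """
    baseline = reward.mean(dim=-1, keepdim=True)
    return reward - baseline

# illustrative shapes: 8 instances, n_start = 4 starts each
reward = torch.randn(8, 4)
advantage = multistart_advantage(reward)
print(advantage.sum(dim=-1))  # ~0 for every instance by construction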
2 changes: 1 addition & 1 deletion rl4co/models/zoo/l2d/policy.py
@@ -225,7 +225,7 @@ def evaluate(self, td):

         return action_logprobs, value_pred, dist_entropys

-    def act(self, td, env, phase: str = "train"):
+    def act(self, td, env, phase: str = "train", temp: float = 1.0):
         logits, mask = self.decoder(td, hidden=None, num_starts=0)
         logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping)
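The new temp argument is not consumed in the lines shown here, so the snippet below only sketches the common convention for a sampling temperature (dividing the logits by temp before sampling, so the temp=2.0 used during rollouts above flattens the distribution). The helper name and masking convention are assumptions for illustration, not the implementation in policy.py.

import torch
from torch.distributions import Categorical

def sample_with_temperature(logits: torch.Tensor, mask: torch.Tensor, temp: float = 1.0):
    """Sample one action per row from temperature-scaled, masked logits.

    Assumes mask == True marks feasible actions. temp > 1 flattens the
    distribution (more exploration); temp < 1 sharpens it.
    """
    scaled = (logits / temp).masked_fill(~mask, float("-inf"))
    dist = Categorical(logits=scaled)
    action = dist.sample()
    return action, dist.log_prob(action)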

