[Chore] woopsie revert new syntax ai4co#237
fedebotu committed Jan 14, 2025
1 parent 8012b2c commit 25f9fcb
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions rl4co/models/zoo/gfacs/model.py
@@ -1,14 +1,16 @@
 import math
 
 from typing import Optional, Union
 
 import numpy as np
 import scipy
-from tensordict import TensorDict
 import torch
+
+from tensordict import TensorDict
+
 from rl4co.envs.common.base import RL4COEnvBase
-from rl4co.models.zoo.deepaco import DeepACO
 from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline
+from rl4co.models.zoo.deepaco import DeepACO
 from rl4co.models.zoo.gfacs.policy import GFACSPolicy
 from rl4co.utils.ops import unbatchify
 
@@ -43,7 +45,9 @@ def __init__(
     ):
         if policy is None:
             policy = GFACSPolicy(
-                env_name=env.name, train_with_local_search=train_with_local_search, **policy_kwargs
+                env_name=env.name,
+                train_with_local_search=train_with_local_search,
+                **policy_kwargs,
             )
 
         super().__init__(
@@ -64,7 +68,9 @@ def __init__(
     @property
     def beta(self) -> float:
         return self.beta_min + (self.beta_max - self.beta_min) * min(
-            math.log(self.current_epoch + 1) / math.log(self.trainer.max_epochs - self.beta_flat_epochs), 1.0
+            math.log(self.current_epoch + 1)
+            / math.log(self.trainer.max_epochs - self.beta_flat_epochs),
+            1.0,
         )
 
     def calculate_loss(
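
Note: as a reading aid only, here is a minimal standalone sketch of the annealing schedule that the `beta` property above implements. The function name and arguments simply mirror the attributes referenced in the diff (`beta_min`, `beta_max`, `beta_flat_epochs`, `self.trainer.max_epochs`, `self.current_epoch`); the example values are made up and are not defaults from the commit.

import math

def annealed_beta(current_epoch, max_epochs, beta_min, beta_max, beta_flat_epochs):
    # Logarithmic ramp from beta_min to beta_max, clipped at 1.0 so that beta
    # stays at beta_max once current_epoch + 1 >= max_epochs - beta_flat_epochs.
    progress = math.log(current_epoch + 1) / math.log(max_epochs - beta_flat_epochs)
    return beta_min + (beta_max - beta_min) * min(progress, 1.0)

# Hypothetical values: beta ramps 1.0 -> 100.0 over the first 45 of 50 epochs.
print(annealed_beta(0, 50, 1.0, 100.0, 5))   # 1.0
print(annealed_beta(44, 50, 1.0, 100.0, 5))  # 100.0 (and for all later epochs)
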
@@ -91,36 +97,50 @@ def calculate_loss(
         if self.train_with_local_search:
             ls_reward = policy_out["ls_reward"]
             ls_advantage = ls_reward - ls_reward.mean(dim=1, keepdim=True)
-            weighted_advantage = advantage * (1 - self.ls_reward_aug_W) + ls_advantage * self.ls_reward_aug_W
+            weighted_advantage = (
+                advantage * (1 - self.ls_reward_aug_W)
+                + ls_advantage * self.ls_reward_aug_W
+            )
         else:
             weighted_advantage = advantage
 
         # On-policy loss
         forward_flow = policy_out["log_likelihood"] + policy_out["logZ"].repeat(1, n_ants)
-        backward_flow = self.calculate_log_pb_uniform(policy_out["actions"], n_ants) + weighted_advantage.detach() * self.beta
+        backward_flow = (
+            self.calculate_log_pb_uniform(policy_out["actions"], n_ants)
+            + weighted_advantage.detach() * self.beta
+        )
         tb_loss = torch.pow(forward_flow - backward_flow, 2).mean()
 
         # Off-policy loss
         if self.train_with_local_search:
-            ls_forward_flow = policy_out["ls_log_likelihood"] + policy_out["ls_logZ"].repeat(1, n_ants)
-            ls_backward_flow = self.calculate_log_pb_uniform(policy_out["ls_actions"], n_ants) + ls_advantage.detach() * self.beta
+            ls_forward_flow = policy_out["ls_log_likelihood"] + policy_out[
+                "ls_logZ"
+            ].repeat(1, n_ants)
+            ls_backward_flow = (
+                self.calculate_log_pb_uniform(policy_out["ls_actions"], n_ants)
+                + ls_advantage.detach() * self.beta
+            )
             ls_tb_loss = torch.pow(ls_forward_flow - ls_backward_flow, 2).mean()
             tb_loss = tb_loss + ls_tb_loss
 
         return tb_loss
 
     def calculate_log_pb_uniform(self, actions: torch.Tensor, n_ants: int):
-        if self.env.name == "tsp":
-            return math.log(1 / 2 * actions.size(1))
-        elif self.env.name == "cvrp":
-            _a1 = actions.detach().cpu().numpy()
-            # shape: (batch, max_tour_length)
-            n_nodes = np.count_nonzero(_a1, axis=1)
-            _a2 = _a1[:, 1:] - _a1[:, :-1]
-            n_routes = np.count_nonzero(_a2, axis=1) - n_nodes
-            _a3 = _a1[:, 2:] - _a1[:, :-2]
-            n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes
-            log_b_p = - scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2)
-            return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants)
-        else:
-            raise ValueError(f"Unknown environment: {self.env.name}")
+        match self.env.name:
+            case "tsp":
+                return math.log(1 / 2 * actions.size(1))
+            case "cvrp":
+                _a1 = actions.detach().cpu().numpy()
+                # shape: (batch, max_tour_length)
+                n_nodes = np.count_nonzero(_a1, axis=1)
+                _a2 = _a1[:, 1:] - _a1[:, :-1]
+                n_routes = np.count_nonzero(_a2, axis=1) - n_nodes
+                _a3 = _a1[:, 2:] - _a1[:, :-2]
+                n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes
+                log_b_p = -scipy.special.gammaln(
+                    n_routes + 1
+                ) - n_multinode_routes * math.log(2)
+                return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants)
+            case _:
+                raise ValueError(f"Unknown environment: {self.env.name}")
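
A note on the `calculate_log_pb_uniform` hunk above: its result feeds the `backward_flow` term of the `tb_loss` (trajectory-balance) computation earlier in the diff. Below is a small sketch, on a made-up CVRP action sequence (depot index 0, zero padding at the end), of what the lag-1 and lag-2 difference counts in the cvrp branch compute; reading the result as "n_routes! route orderings times 2 directions per multi-node route" is an interpretation of the code, not something stated in the commit.

import math
import numpy as np
from scipy.special import gammaln

# Hypothetical flat CVRP solution with routes [3, 1], [2], [4, 5], zero-padded.
a = np.array([[0, 3, 1, 0, 2, 0, 4, 5, 0, 0]])

n_nodes = np.count_nonzero(a, axis=1)                                   # [5] customer visits
n_routes = np.count_nonzero(a[:, 1:] - a[:, :-1], axis=1) - n_nodes     # [3] routes
n_multinode = np.count_nonzero(a[:, 2:] - a[:, :-2], axis=1) - n_nodes  # [2] routes with >1 node

# Uniform log-probability over the n_routes! * 2**n_multinode equivalent
# action sequences that encode the same set of routes: here -log(3! * 2**2).
log_b_p = -gammaln(n_routes + 1) - n_multinode * math.log(2)
print(log_b_p)  # [-3.178...] == -log(24)
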
