
Commit

Merge pull request #125 from ai4co/bugfix-deepcopy
Bugfix deepcopy
fedebotu authored Mar 2, 2024
2 parents 2b8bcb9 + 91dcc1f commit 847c48a
Showing 6 changed files with 58 additions and 23 deletions.
2 changes: 1 addition & 1 deletion rl4co/__init__.py
@@ -1 +1 @@
-__version__ = "0.3.2"
+__version__ = "0.3.3"
12 changes: 6 additions & 6 deletions rl4co/envs/routing/svrp.py
@@ -64,8 +64,8 @@ def _make_spec(self, td_params: TensorDict = None):
        """Make the observation and action specs from the parameters."""
        self.observation_spec = CompositeSpec(
            locs=BoundedTensorSpec(
-                minimum=self.min_loc,
-                maximum=self.max_loc,
+                low=self.min_loc,
+                high=self.max_loc,
                shape=(self.num_loc + 1, 2),
                dtype=torch.float32,
            ),
@@ -74,8 +74,8 @@ def _make_spec(self, td_params: TensorDict = None):
                dtype=torch.int64,
            ),
            skills=BoundedTensorSpec(
-                minimum=self.min_skill,
-                maximum=self.max_skill,
+                low=self.min_skill,
+                high=self.max_skill,
                shape=(self.num_loc, 1),
                dtype=torch.float32,
            ),
@@ -88,8 +88,8 @@ def _make_spec(self, td_params: TensorDict = None):
        self.action_spec = BoundedTensorSpec(
            shape=(1,),
            dtype=torch.int64,
-            minimum=0,
-            maximum=self.num_loc + 1,
+            low=0,
+            high=self.num_loc + 1,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,), dtype=torch.float32)
        self.done_spec = UnboundedDiscreteTensorSpec(shape=(1,), dtype=torch.bool)
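
Note: torchrl renamed the BoundedTensorSpec bound arguments from minimum/maximum to low/high, which is the signature this file now targets. A minimal standalone sketch, assuming a recent torchrl release and using placeholder bounds and shapes in place of the environment attributes:

import torch
from torchrl.data import BoundedTensorSpec

# Spec analogous to the `locs` entry above: 2D coordinates bounded by [min_loc, max_loc]
loc_spec = BoundedTensorSpec(
    low=0.0,        # stands in for self.min_loc
    high=1.0,       # stands in for self.max_loc
    shape=(21, 2),  # stands in for (self.num_loc + 1, 2)
    dtype=torch.float32,
)

sample = loc_spec.rand()       # random observation drawn inside the bounds
assert loc_spec.is_in(sample)  # every value lies within [low, high]
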
33 changes: 28 additions & 5 deletions rl4co/models/nn/dec_strategies.py
@@ -32,20 +32,39 @@ def get_decoding_strategy(decoding_strategy, **config):


class DecodingStrategy:
+    """Base class for decoding strategies. Subclasses should implement the :meth:`_step` method.
+    Includes hooks for pre and post main decoding operations.
+    Args:
+        multistart (bool, optional): Whether to use multistart decoding. Defaults to False.
+        num_starts (int, optional): Number of starts for multistart decoding. Defaults to None.
+        select_start_nodes_fn (Callable, optional): Function to select start nodes. Defaults to select_start_nodes.
+    """

    name = "base"

-    def __init__(self, multistart=False, num_starts=None, **kwargs) -> None:
+    def __init__(
+        self,
+        multistart=False,
+        num_starts=None,
+        select_start_nodes_fn=select_start_nodes,
+        **kwargs,
+    ) -> None:

        self.actions = []
        self.logp = []
        self.multistart = multistart
        self.num_starts = num_starts
+        self.select_start_nodes_fn = select_start_nodes_fn

    def _step(
        self, logp: torch.Tensor, td: TensorDict, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Main decoding operation. This method should be implemented by subclasses."""
        raise NotImplementedError("Must be implemented by subclass")

    def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase):
+        """Pre decoding hook. This method is called before the main decoding operation."""
        # Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
        if self.multistart:
            if self.num_starts is None:
@@ -61,7 +80,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase):

        # Multi-start decoding: first action is chosen by ad-hoc node selection
        if self.num_starts > 1:
-            action = select_start_nodes(td, env, self.num_starts)
+            action = self.select_start_nodes_fn(td, env, self.num_starts)

            # Expand td to batch_size * num_starts
            td = batchify(td, self.num_starts)
@@ -78,6 +97,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase):
        return td, env, self.num_starts

    def post_decoder_hook(self, td, env):
+        """Post decoding hook. This method is called after the main decoding operation."""
        assert (
            len(self.logp) > 0
        ), "No outputs were collected because all environments were done. Check your initial state"
@@ -87,6 +107,7 @@ def post_decoder_hook(self, td, env):
    def step(
        self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]:
+        """Main decoding operation. This method calls the :meth:`_step` method and collects the outputs."""
        assert not logp.isinf().all(1).any()

        logp, selected_actions, td = self._step(logp, mask, td, **kwargs)
@@ -103,11 +124,12 @@ class Greedy(DecodingStrategy):
    name = "greedy"

    def __init__(self, multistart=False, num_starts=None, **kwargs) -> None:
-        super().__init__(multistart=multistart, num_starts=num_starts)
+        super().__init__(multistart=multistart, num_starts=num_starts, **kwargs)

    def _step(
        self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]:
+        """Select the action with the highest log probability."""
        # [BS], [BS]
        _, selected = logp.max(1)

@@ -122,11 +144,12 @@ class Sampling(DecodingStrategy):
    name = "sampling"

    def __init__(self, multistart=False, num_starts=None, **kwargs) -> None:
-        super().__init__(multistart=multistart, num_starts=num_starts)
+        super().__init__(multistart=multistart, num_starts=num_starts, **kwargs)

    def _step(
        self, logp: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]:
+        """Sample an action with a multinomial distribution given by the log probabilities."""
        probs = logp.exp()
        selected = torch.multinomial(probs, 1).squeeze(1)

@@ -171,7 +194,7 @@ def pre_decoder_hook(self, td: TensorDict, env: RL4COEnvBase, **kwargs):
            self.beam_width = get_num_starts(td, env.name)

        # select start nodes. TODO: include first step in beam search as well
-        action = select_start_nodes(td, env, self.beam_width)
+        action = self.select_start_nodes_fn(td, env, self.beam_width)

        # Expand td to batch_size * beam_width
        td = batchify(td, self.beam_width)
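
The new select_start_nodes_fn argument lets callers swap in their own start-node selection for multistart and beam decoding; the strategy invokes it as fn(td, env, num_starts) inside pre_decoder_hook. A rough sketch with a hypothetical selector (start_at_node_zero is illustrative only; a real selector must return start actions laid out the way rl4co's batchify expects):

import torch
from rl4co.models.nn.dec_strategies import Greedy

def start_at_node_zero(td, env, num_starts):
    # Hypothetical selector: every rollout, for every start and batch element,
    # begins at node 0. Must return a flat tensor of length batch_size * num_starts.
    batch_size = td.batch_size[0]
    return torch.zeros(batch_size * num_starts, dtype=torch.long)

strategy = Greedy(
    multistart=True,
    num_starts=5,
    select_start_nodes_fn=start_at_node_zero,  # forwarded through **kwargs to DecodingStrategy
)
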
23 changes: 14 additions & 9 deletions rl4co/models/zoo/common/autoregressive/decoder.py
@@ -45,7 +45,7 @@ class AutoregressiveDecoder(nn.Module):
    We suppose environments in the `done` state are still available for sampling. This is because in NCO we need to
    wait for all the environments to reach a terminal state before we can stop the decoding process. This is in
    contrast with the TorchRL framework (at the moment) where the `env.rollout` function automatically resets.
-    You may follow tighter integration with TorchRL here: https://github.com/kaist-silab/rl4co/issues/72.
+    You may follow tighter integration with TorchRL here: https://github.com/ai4co/rl4co/issues/72.

    Args:
        env_name: environment name to solve
@@ -60,7 +60,7 @@ class AutoregressiveDecoder(nn.Module):

    def __init__(
        self,
-        env_name: [str, RL4COEnvBase],
+        env_name: Union[str, RL4COEnvBase],
        embedding_dim: int,
        num_heads: int,
        use_graph_context: bool = True,
@@ -111,8 +111,6 @@ def __init__(

        self.select_start_nodes_fn = select_start_nodes_fn

-        self.decode_strategy = None
-
    def forward(
        self,
        td: TensorDict,
@@ -154,19 +152,26 @@ def forward(
        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute_cache(embeddings, td=td)

-        # setup decoding strategy
-        self.decode_strategy: DecodingStrategy = get_decoding_strategy(
+        # If `select_start_nodes_fn` is not being passed, we use the class attribute
+        if "select_start_nodes_fn" not in strategy_kwargs:
+            strategy_kwargs["select_start_nodes_fn"] = self.select_start_nodes_fn
+
+        # Setup decoding strategy
+        decode_strategy: DecodingStrategy = get_decoding_strategy(
            decode_type, **strategy_kwargs
        )
-        td, env, num_starts = self.decode_strategy.pre_decoder_hook(td, env)

+        # Pre-decoding hook: used for the initial step(s) of the decoding strategy
+        td, env, num_starts = decode_strategy.pre_decoder_hook(td, env)
+
+        # Main decoding: loop until all sequences are done
        while not td["done"].all():
            log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)
-            td = self.decode_strategy.step(log_p, mask, td)
+            td = decode_strategy.step(log_p, mask, td)
            td = env.step(td)["next"]

-        outputs, actions, td, env = self.decode_strategy.post_decoder_hook(td, env)
+        # Post-decoding hook: used for the final step(s) of the decoding strategy
+        outputs, actions, td, env = decode_strategy.post_decoder_hook(td, env)

        if calc_reward:
            td.set("reward", env.get_reward(td, actions))
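
The decoder now keeps the decoding strategy as a local variable in forward() rather than as the self.decode_strategy attribute, so the module no longer carries rollout-time state that a deep copy would have to traverse. A minimal sketch of the behaviour this enables (assuming the default "tsp" policy; copies like this are used e.g. when a baseline keeps a frozen snapshot of the policy):

import copy
from rl4co.models import AutoregressivePolicy

policy = AutoregressivePolicy("tsp")

# The policy holds no per-rollout state, so a deep copy succeeds
# and both modules can decode independently afterwards.
policy_copy = copy.deepcopy(policy)
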
2 changes: 1 addition & 1 deletion rl4co/models/zoo/common/autoregressive/policy.py
@@ -51,7 +51,7 @@ class AutoregressivePolicy(nn.Module):

    def __init__(
        self,
-        env_name: [str, RL4COEnvBase] = "tsp",
+        env_name: Union[str, RL4COEnvBase] = "tsp",
        encoder: nn.Module = None,
        decoder: nn.Module = None,
        init_embedding: nn.Module = None,
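
The corrected annotation Union[str, RL4COEnvBase] replaces the invalid list form [str, RL4COEnvBase]; per the hint, env_name may be either an environment name or an environment instance. A small sketch of both forms (assuming the policy resolves an instance to its registered name, as the annotation implies):

from rl4co.envs import TSPEnv
from rl4co.models import AutoregressivePolicy

policy_from_name = AutoregressivePolicy("tsp")              # env_name as a string
policy_from_env = AutoregressivePolicy(TSPEnv(num_loc=20))  # env_name as an RL4COEnvBase instance
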
9 changes: 8 additions & 1 deletion tests/test_policy.py
@@ -2,6 +2,7 @@

from rl4co.models import AutoregressivePolicy, PointerNetworkPolicy
from rl4co.utils.test_utils import generate_env_data
+from rl4co.utils.ops import select_start_nodes


# Main autorergressive policy: rollout over multiple envs since it is the base
@@ -25,7 +26,13 @@ def test_base_policy_multistart(env_name, size=20, batch_size=2):
    td = env.reset(x)
    policy = AutoregressivePolicy(env.name)
    num_starts = size // 2 if env.name in ["pdp"] else size
-    out = policy(td, env, decode_type="multistart_greedy", num_starts=num_starts)
+    out = policy(
+        td,
+        env,
+        decode_type="multistart_greedy",
+        num_starts=num_starts,
+        select_start_nodes_fn=select_start_nodes,
+    )
    assert out["reward"].shape == (
        batch_size * num_starts,
    )  # to evaluate, we could just unbatchify
