From 0783c0e6401031e8df069a3c163997deaa8598e7 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 4 Apr 2022 16:52:44 +0200 Subject: [PATCH 1/4] Try REDQ quickly --- stable_baselines3/common/policies.py | 5 ++-- stable_baselines3/sac/sac.py | 36 +++++++++++++++++----------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 33918b784..fe36b2e78 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -877,13 +877,14 @@ def __init__( self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + def forward(self, obs: th.Tensor, actions: th.Tensor, q_networks=None) -> Tuple[th.Tensor, ...]: + q_networks = q_networks or self.q_networks # Learn the features extractor using the policy loss only # when the features_extractor is shared with the actor with th.set_grad_enabled(not self.share_features_extractor): features = self.extract_features(obs) qvalue_input = th.cat([features, actions], dim=1) - return tuple(q_net(qvalue_input) for q_net in self.q_networks) + return tuple(q_net(qvalue_input) for q_net in q_networks) def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: """ diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 5f3a83395..1a1ae1f32 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -193,8 +193,10 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: ent_coef_losses, ent_coefs = [], [] actor_losses, critic_losses = [], [] + policy_update_delay = gradient_steps for gradient_step in range(gradient_steps): + update_actor = ((gradient_step + 1) % policy_update_delay == 0) or gradient_step == gradient_steps - 1 # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) @@ -203,8 +205,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: self.actor.reset_noise() # Action by the current actor for the sampled state - actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations) - log_prob = log_prob.reshape(-1, 1) + if update_actor: + actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations) + log_prob = log_prob.reshape(-1, 1) ent_coef_loss = None if self.ent_coef_optimizer is not None: @@ -212,8 +215,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 ent_coef = th.exp(self.log_ent_coef.detach()) - ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() - ent_coef_losses.append(ent_coef_loss.item()) + if update_actor: + ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() + ent_coef_losses.append(ent_coef_loss.item()) else: ent_coef = self.ent_coef_tensor @@ -229,8 +233,10 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: with th.no_grad(): # Select action according to policy next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) + q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2] + q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices] # Compute the next Q values: min over all critics targets - next_q_values = 
th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions, q_networks), dim=1) next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) # add entropy term next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) @@ -253,15 +259,17 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Compute actor loss # Alternative: actor_loss = th.mean(log_prob - qf1_pi) # Mean over all critic networks - q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1) - min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) - actor_loss = (ent_coef * log_prob - min_qf_pi).mean() - actor_losses.append(actor_loss.item()) - - # Optimize the actor - self.actor.optimizer.zero_grad() - actor_loss.backward() - self.actor.optimizer.step() + if update_actor: + q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1) + # Note: REDQ does a mean here + min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) + actor_loss = (ent_coef * log_prob - min_qf_pi).mean() + actor_losses.append(actor_loss.item()) + + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() # Update target networks if gradient_step % self.target_update_interval == 0: From e7be8dc0522d4a01b3fd4ea6d6476e9845f84e5a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 17 Aug 2022 15:15:56 +0200 Subject: [PATCH 2/4] Implement DropQ --- stable_baselines3/common/policies.py | 6 +++++- stable_baselines3/common/torch_layers.py | 14 +++++++++++++- stable_baselines3/sac/policies.py | 5 +++++ stable_baselines3/sac/sac.py | 16 +++++++++------- tests/test_run.py | 11 +++++++++++ 5 files changed, 43 insertions(+), 9 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index d122acd67..c8c87e894 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -859,6 +859,8 @@ def __init__( normalize_images: bool = True, n_critics: int = 2, share_features_extractor: bool = True, + dropout_rate: float = 0.0, + layer_norm: bool = False, ): super().__init__( observation_space, @@ -873,7 +875,9 @@ def __init__( self.n_critics = n_critics self.q_networks = [] for idx in range(n_critics): - q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = create_mlp( + features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm + ) q_net = nn.Sequential(*q_net) self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index f87337c62..525ffd410 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -99,6 +99,8 @@ def create_mlp( net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, + dropout_rate: float = 0.0, + layer_norm: bool = False, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is @@ -117,12 +119,22 @@ def create_mlp( """ if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] + additional_modules = [] + if dropout_rate > 0.0: + additional_modules.append(nn.Dropout(p=dropout_rate)) + if layer_norm: + additional_modules.append(nn.LayerNorm(net_arch[0])) + modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()] + 
     else:
         modules = []
 
     for idx in range(len(net_arch) - 1):
         modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
+        if dropout_rate > 0.0:
+            modules.append(nn.Dropout(p=dropout_rate))
+        if layer_norm:
+            modules.append(nn.LayerNorm(net_arch[idx + 1]))
         modules.append(activation_fn())
 
     if output_dim > 0:
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index 255bd7554..392f0881a 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -236,6 +236,9 @@ def __init__(
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        # For the critic only
+        dropout_rate: float = 0.0,
+        layer_norm: bool = False,
     ):
         super().__init__(
             observation_space,
@@ -279,6 +282,8 @@
                 "n_critics": n_critics,
                 "net_arch": critic_arch,
                 "share_features_extractor": share_features_extractor,
+                "dropout_rate": dropout_rate,
+                "layer_norm": layer_norm,
             }
         )
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index b7cbcf6a1..6c238c72b 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -239,10 +239,11 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         with th.no_grad():
             # Select action according to policy
             next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
-            q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
-            q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
+            # For REDQ, sample q networks to be used
+            # q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
+            # q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
             # Compute the next Q values: min over all critics targets
-            next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions, q_networks), dim=1)
+            next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
             next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
             # add entropy term
             next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
@@ -264,12 +265,13 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
 
             # Compute actor loss
             # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
-            # Mean over all critic networks
+            # Mean over all critic networks (REDQ/DropQ style)
             if update_actor:
                 q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
-                # Note: REDQ does a mean here
-                min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-                actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
+                # Note: REDQ and DropQ do a mean here
+                # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
+                mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True)
+                actor_loss = (ent_coef * log_prob - mean_qf_pi).mean()
                 actor_losses.append(actor_loss.item())
 
                 # Optimize the actor
diff --git a/tests/test_run.py b/tests/test_run.py
index b0a9a11c5..7919d629d 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -89,6 +89,17 @@ def test_sac(ent_coef):
     model.learn(total_timesteps=300, eval_freq=250)
 
 
+def test_dropq():
+    model = SAC(
+        "MlpPolicy",
+        "Pendulum-v1",
+        policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
+        verbose=1,
+        buffer_size=250,
+    )
+    model.learn(total_timesteps=300)
+
+
 @pytest.mark.parametrize("n_critics", [1, 3])
 def test_n_critics(n_critics):
     # Test SAC with different number of critics, for 
TD3, n_critics=1 corresponds to DDPG From 4114e9ad148922a7209c7953ce5f66a034b7894a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 28 Aug 2022 18:56:34 +0200 Subject: [PATCH 3/4] Add TODO --- stable_baselines3/sac/sac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 7b56d46d4..384bbb092 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -202,6 +202,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: ent_coef_losses, ent_coefs = [], [] actor_losses, critic_losses = [], [] + # TODO: properly handle it when train_freq > 1 policy_update_delay = gradient_steps for gradient_step in range(gradient_steps): From aa60e711e107dfdbb4a8597e097a3074d789c78f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 29 Aug 2022 10:51:45 +0200 Subject: [PATCH 4/4] Add policy delay --- stable_baselines3/sac/sac.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 384bbb092..ed8293286 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -95,6 +95,7 @@ def __init__( replay_buffer_class: Optional[ReplayBuffer] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, + policy_delay: int = 1, ent_coef: Union[str, float] = "auto", target_update_interval: int = 1, target_entropy: Union[str, float] = "auto", @@ -145,6 +146,7 @@ def __init__( self.ent_coef = ent_coef self.target_update_interval = target_update_interval self.ent_coef_optimizer = None + self.policy_delay = policy_delay if _init_setup_model: self._setup_model() @@ -202,11 +204,10 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: ent_coef_losses, ent_coefs = [], [] actor_losses, critic_losses = [], [] - # TODO: properly handle it when train_freq > 1 - policy_update_delay = gradient_steps for gradient_step in range(gradient_steps): - update_actor = ((gradient_step + 1) % policy_update_delay == 0) or gradient_step == gradient_steps - 1 + self._n_updates += 1 + update_actor = self._n_updates % self.policy_delay == 0 # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) @@ -288,8 +289,6 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Copy running stats, see GH issue #996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) - self._n_updates += gradient_steps - self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/ent_coef", np.mean(ent_coefs)) self.logger.record("train/actor_loss", np.mean(actor_losses))
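
For reference, a minimal standalone sketch of the subset-min target tried in PATCH 1/4 (and commented out again in PATCH 2/4): take the min over a randomly sampled subset of M target critics instead of over all of them, as in REDQ. This is plain PyTorch, not the SB3 API; the names (q_target_nets, qvalue_input) and M=2 are illustrative assumptions.

# Illustrative sketch only, not part of the patches.
import numpy as np
import torch as th
import torch.nn as nn


def subset_min_q(q_target_nets: nn.ModuleList, qvalue_input: th.Tensor, m: int = 2) -> th.Tensor:
    # Sample m target critics without replacement and take the element-wise min over them
    indices = np.random.permutation(len(q_target_nets))[:m]
    q_values = th.cat([q_target_nets[int(idx)](qvalue_input) for idx in indices], dim=1)
    min_q, _ = th.min(q_values, dim=1, keepdim=True)
    return min_q


# The SAC TD target is then built from this subset min:
# target_q = rewards + (1 - dones) * gamma * (subset_min_q(...) - ent_coef * next_log_prob)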
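
A usage sketch for the end state of the series (dropout plus layer norm on the critic only, delayed actor updates, several gradient steps per environment step), assuming all four patches are applied. The hyperparameter values below are illustrative, not taken from the patches; the test above uses dropout_rate=0.005 and tiny budgets only to exercise the code path.

from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    # Critic-only regularization added in PATCH 2/4 (dropout + layer normalization)
    policy_kwargs=dict(net_arch=[256, 256], n_critics=2, dropout_rate=0.01, layer_norm=True),
    # High update-to-data ratio, with the delayed actor updates added in PATCH 4/4
    train_freq=1,
    gradient_steps=20,
    policy_delay=20,
    learning_starts=1_000,
    verbose=1,
)
model.learn(total_timesteps=20_000)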