Feat/dropq #1036

Draft · wants to merge 9 commits into base: master
11 changes: 8 additions & 3 deletions stable_baselines3/common/policies.py
@@ -832,6 +832,8 @@ def __init__(
normalize_images: bool = True,
n_critics: int = 2,
share_features_extractor: bool = True,
+ dropout_rate: float = 0.0,
+ layer_norm: bool = False,
):
super().__init__(
observation_space,
@@ -846,18 +848,21 @@ def __init__(
self.n_critics = n_critics
self.q_networks = []
for idx in range(n_critics):
- q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
+ q_net = create_mlp(
+     features_dim + action_dim, 1, net_arch, activation_fn, dropout_rate=dropout_rate, layer_norm=layer_norm
+ )
q_net = nn.Sequential(*q_net)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)

- def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+ def forward(self, obs: th.Tensor, actions: th.Tensor, q_networks=None) -> Tuple[th.Tensor, ...]:
+     q_networks = q_networks or self.q_networks
# Learn the features extractor using the policy loss only
# when the features_extractor is shared with the actor
with th.set_grad_enabled(not self.share_features_extractor):
features = self.extract_features(obs)
qvalue_input = th.cat([features, actions], dim=1)
- return tuple(q_net(qvalue_input) for q_net in self.q_networks)
+ return tuple(q_net(qvalue_input) for q_net in q_networks)

def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
"""
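Note: a minimal sketch of how the new optional `q_networks` argument could be used to evaluate only a random subset of the target critics, matching the commented-out REDQ lines in `sac.py` further down. The helper name and the subset size `m` are illustrative and not part of this diff.

```python
# Illustrative sketch only (not part of this PR): query a random subset of
# target critics through the new `q_networks` keyword of ContinuousCritic.forward.
import numpy as np
import torch as th


def sample_target_q_values(critic_target, obs, actions, m: int = 2):
    # Draw m critics at random from the ensemble (REDQ-style subset)
    indices = np.random.permutation(len(critic_target.q_networks))[:m]
    subset = [critic_target.q_networks[idx] for idx in indices]
    # Evaluate only the sampled critics via the new keyword argument
    q_values = th.cat(critic_target(obs, actions, q_networks=subset), dim=1)
    # Min over the sampled subset, used as the target Q-value
    return th.min(q_values, dim=1, keepdim=True).values
```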
14 changes: 13 additions & 1 deletion stable_baselines3/common/torch_layers.py
@@ -99,6 +99,8 @@ def create_mlp(
net_arch: List[int],
activation_fn: Type[nn.Module] = nn.ReLU,
squash_output: bool = False,
+ dropout_rate: float = 0.0,
+ layer_norm: bool = False,
) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
@@ -117,12 +119,22 @@
"""

if len(net_arch) > 0:
- modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
+ additional_modules = []
+ if dropout_rate > 0.0:
+     additional_modules.append(nn.Dropout(p=dropout_rate))
+ if layer_norm:
+     additional_modules.append(nn.LayerNorm(net_arch[0]))
+ modules = [nn.Linear(input_dim, net_arch[0])] + additional_modules + [activation_fn()]

else:
modules = []

for idx in range(len(net_arch) - 1):
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
+ if dropout_rate > 0.0:
+     modules.append(nn.Dropout(p=dropout_rate))
+ if layer_norm:
+     modules.append(nn.LayerNorm(net_arch[idx + 1]))
modules.append(activation_fn())

if output_dim > 0:
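Note: a quick usage sketch of the patched `create_mlp` (the `dropout_rate` and `layer_norm` keywords exist only on this branch); the dimensions and rates below are arbitrary example values.

```python
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# Build a DroQ-style Q-value head: Dropout and LayerNorm are inserted after
# every Linear layer, before the activation.
layers = create_mlp(
    input_dim=64 + 1,   # e.g. features_dim + action_dim
    output_dim=1,       # scalar Q-value
    net_arch=[256, 256],
    activation_fn=nn.ReLU,
    dropout_rate=0.01,
    layer_norm=True,
)
q_net = nn.Sequential(*layers)
# Each hidden block is now Linear -> Dropout -> LayerNorm -> ReLU
```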
5 changes: 5 additions & 0 deletions stable_baselines3/sac/policies.py
@@ -223,6 +223,9 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
+ # For the critic only
+ dropout_rate: float = 0.0,
+ layer_norm: bool = False,
):
super().__init__(
observation_space,
@@ -263,6 +266,8 @@ def __init__(
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
"dropout_rate": dropout_rate,
"layer_norm": layer_norm,
}
)

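Note: to illustrate the plumbing above (again, only valid on this branch), the new keywords travel from `policy_kwargs` through `critic_kwargs` into every q-network of the critic.

```python
# Sketch (requires this branch): the dropout / layer-norm options are forwarded
# from policy_kwargs into each q-network built by ContinuousCritic.
from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
)
# Each q-network should now contain Dropout and LayerNorm modules
print(model.critic.q_networks[0])
```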
41 changes: 25 additions & 16 deletions stable_baselines3/sac/sac.py
@@ -96,6 +96,7 @@ def __init__(
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
+ policy_delay: int = 1,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
@@ -144,6 +145,7 @@ def __init__(
self.ent_coef = ent_coef
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
+ self.policy_delay = policy_delay

if _init_setup_model:
self._setup_model()
@@ -203,6 +205,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
actor_losses, critic_losses = [], []

for gradient_step in range(gradient_steps):
+ self._n_updates += 1
+ update_actor = self._n_updates % self.policy_delay == 0
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

@@ -211,17 +215,19 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
self.actor.reset_noise()

# Action by the current actor for the sampled state
- actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
- log_prob = log_prob.reshape(-1, 1)
+ if update_actor:
+     actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
+     log_prob = log_prob.reshape(-1, 1)

ent_coef_loss = None
if self.ent_coef_optimizer is not None:
# Important: detach the variable from the graph
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
- ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
- ent_coef_losses.append(ent_coef_loss.item())
+ if update_actor:
+     ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
+     ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor

@@ -237,6 +243,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
with th.no_grad():
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
+ # For REDQ, sample q networks to be used
+ # q_networks_indices = np.random.permutation(len(self.critic_target.q_networks))[:2]
+ # q_networks = [q_net for idx, q_net in enumerate(self.critic_target.q_networks) if idx in q_networks_indices]
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
@@ -260,25 +269,25 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:

# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
- # Min over all critic networks
- q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
- min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
- actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
- actor_losses.append(actor_loss.item())
-
- # Optimize the actor
- self.actor.optimizer.zero_grad()
- actor_loss.backward()
- self.actor.optimizer.step()
+ if update_actor:
+     q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
+     # Note: REDQ and DropQ take the mean here instead of the min
+     # min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
+     mean_qf_pi = th.mean(q_values_pi, dim=1, keepdim=True)
+     actor_loss = (ent_coef * log_prob - mean_qf_pi).mean()
+     actor_losses.append(actor_loss.item())
+
+     # Optimize the actor
+     self.actor.optimizer.zero_grad()
+     actor_loss.backward()
+     self.actor.optimizer.step()

# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
# Copy running stats, see GH issue #996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

- self._n_updates += gradient_steps

self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses))
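Note: putting the `sac.py` changes together, a DroQ-like setup combines a high update-to-data ratio with a delayed actor update. A configuration sketch, assuming this branch (`policy_delay` and the critic regularization keywords are not in released SB3); the hyperparameter values are illustrative only.

```python
from stable_baselines3 import SAC

model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    train_freq=1,
    gradient_steps=20,   # high update-to-data ratio: 20 critic updates per env step
    policy_delay=20,     # actor and entropy coefficient updated every 20 critic updates
    learning_starts=1_000,
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    verbose=1,
)
model.learn(total_timesteps=5_000)
```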
11 changes: 11 additions & 0 deletions tests/test_run.py
@@ -86,6 +86,17 @@ def test_sac(ent_coef):
model.learn(total_timesteps=200)


+ def test_dropq():
+     model = SAC(
+         "MlpPolicy",
+         "Pendulum-v1",
+         policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
+         verbose=1,
+         buffer_size=250,
+     )
+     model.learn(total_timesteps=300)


@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG