diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py index c81479f06b..d7c1f25a18 100644 --- a/ding/model/template/q_learning.py +++ b/ding/model/template/q_learning.py @@ -37,6 +37,7 @@ def __init__( norm_type: Optional[str] = None, dropout: Optional[float] = None, init_bias: Optional[float] = None, + noise: bool = False, ) -> None: """ Overview: @@ -57,6 +58,8 @@ def __init__( - dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \ if ``None`` then default disable dropout layer. - init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network. \ + - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` to boost exploration in \ + Q networks' MLP. Default to ``False``. """ super(DQN, self).__init__() # Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4 @@ -90,7 +93,8 @@ def __init__( layer_num=head_layer_num, activation=activation, norm_type=norm_type, - dropout=dropout + dropout=dropout, + noise=noise, ) else: self.head = head_cls( @@ -99,7 +103,8 @@ def __init__( head_layer_num, activation=activation, norm_type=norm_type, - dropout=dropout + dropout=dropout, + noise=noise, ) if init_bias is not None and head_cls == DuelingHead: # Zero the last layer bias of advantage head diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index 6dee15d5e5..3d602be0b4 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,9 +1,28 @@ from typing import List, Any, Dict, Callable import torch +import torch.nn as nn import numpy as np import treetensor.torch as ttorch from ding.utils.data import default_collate from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze +from ding.torch_utils import NoiseLinearLayer + + +def set_noise_mode(module: nn.Module, noise_enabled: bool): + """ + Overview: + Recursively set the 'enable_noise' attribute for all NoiseLinearLayer modules within the given module. + This function is typically used in algorithms such as NoisyNet and Rainbow. + During training, 'enable_noise' should be set to True to enable noise for exploration. + During inference or evaluation, it should be set to False to disable noise for deterministic behavior. + + Arguments: + - module (:obj:`nn.Module`): The root module to search for NoiseLinearLayer instances. + - noise_enabled (:obj:`bool`): Whether to enable or disable noise. + """ + for m in module.modules(): + if isinstance(m, NoiseLinearLayer): + m.enable_noise = noise_enabled def default_preprocess_learn( diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index 8e0944f270..116ba93cf6 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -10,7 +10,7 @@ from ding.utils.data import default_collate, default_decollate from .base_policy import Policy -from .common_utils import default_preprocess_learn +from .common_utils import default_preprocess_learn, set_noise_mode @POLICY_REGISTRY.register('dqn') @@ -97,6 +97,8 @@ class DQNPolicy(Policy): discount_factor=0.97, # (int) The number of steps for calculating target q_value. nstep=1, + # (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is False. + noisy_net=False, model=dict( # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. encoder_hidden_size_list=[128, 128, 64], @@ -248,6 +250,21 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: .. note:: For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # Set noise mode for NoisyNet for exploration in learning if enabled in config + # We need to reset set_noise_mode every _forward_xxx because the model is reused across different + # phases (learn/collect/eval). + if self._cfg.noisy_net: + set_noise_mode(self._learn_model, True) + set_noise_mode(self._target_model, True) + + # A noisy network agent samples a new set of parameters after every step of optimisation. + # Between optimisation steps, the agent acts according to a fixed set of parameters (weights and biases). + # This ensures that the agent always acts according to parameters that are drawn from + # the current noise distribution. + if self._cfg.noisy_net: + self._reset_noise(self._learn_model) + self._reset_noise(self._target_model) + # Data preprocessing operations, such as stack data, cpu to cuda device data = default_preprocess_learn( data, @@ -380,10 +397,17 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: .. note:: For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # Set noise mode for NoisyNet for exploration in collecting if enabled in config. + # We need to reset set_noise_mode every _forward_xxx because the model is reused across different + # phases (learn/collect/eval). + if self._cfg.noisy_net: + set_noise_mode(self._collect_model, True) + data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: data = to_device(data, self._device) + self._collect_model.eval() with torch.no_grad(): output = self._collect_model.forward(data, eps=eps) @@ -472,10 +496,16 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: .. note:: For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # We need to reset set_noise_mode every _forward_xxx because the model is reused across different + # phases (learn/collect/eval). + # Ensure that in evaluation mode noise is disabled. + set_noise_mode(self._eval_model, False) + data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: data = to_device(data, self._device) + self._eval_model.eval() with torch.no_grad(): output = self._eval_model.forward(data) @@ -533,6 +563,18 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F ) return {'priority': td_error_per_sample.abs().tolist()} + def _reset_noise(self, model: torch.nn.Module): + r""" + Overview: + Reset the noise of model. + + Arguments: + - model (:obj:`torch.nn.Module`): the model to reset, must contain reset_noise method + """ + for m in model.modules(): + if hasattr(m, 'reset_noise'): + m.reset_noise() + @POLICY_REGISTRY.register('dqn_stdim') class DQNSTDIMPolicy(DQNPolicy): diff --git a/ding/policy/rainbow.py b/ding/policy/rainbow.py index 1efd00e90b..1309d90ad0 100644 --- a/ding/policy/rainbow.py +++ b/ding/policy/rainbow.py @@ -8,7 +8,7 @@ from ding.utils import POLICY_REGISTRY from ding.utils.data import default_collate, default_decollate from .dqn import DQNPolicy -from .common_utils import default_preprocess_learn +from .common_utils import default_preprocess_learn, set_noise_mode @POLICY_REGISTRY.register('rainbow') @@ -86,8 +86,9 @@ class RainbowDQNPolicy(DQNPolicy): discount_factor=0.99, # (int) N-step reward for target q_value estimation nstep=3, + # (bool) Whether to use NoisyNet for exploration in both learning and collecting. Default is True. + noisy_net=True, learn=dict( - # How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... @@ -201,6 +202,11 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: # ==================== self._learn_model.train() self._target_model.train() + + # Set noise mode for NoisyNet for exploration in learning if enabled in config + set_noise_mode(self._learn_model, True) + set_noise_mode(self._target_model, True) + # reset noise of noisenet for both main model and target model self._reset_noise(self._learn_model) self._reset_noise(self._target_model) @@ -262,12 +268,16 @@ def _forward_collect(self, data: dict, eps: float) -> dict: ReturnsKeys - necessary: ``action`` """ + # Set noise mode for NoisyNet for exploration in collecting if enabled in config + # We need to reset set_noise_mode every _forward_xxx because the model is reused across + # different phases (learn/collect/eval). + set_noise_mode(self._collect_model, True) + data_id = list(data.keys()) data = default_collate(list(data.values())) if self._cuda: data = to_device(data, self._device) self._collect_model.eval() - self._reset_noise(self._collect_model) with torch.no_grad(): output = self._collect_model.forward(data, eps=eps) if self._cuda: diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py index 64a21edfe4..c8bdf44301 100644 --- a/ding/torch_utils/network/nn_module.py +++ b/ding/torch_utils/network/nn_module.py @@ -637,7 +637,10 @@ class NoiseLinearLayer(nn.Module): def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> None: """ Overview: - Initialize the NoiseLinearLayer class. + Initialize the NoiseLinearLayer class. The 'enable_noise' attribute enables external control over whether \ + noise is applied. + - If enable_noise is True, the layer adds noise even if the module is in evaluation mode. + - If enable_noise is False, no noise is added regardless of self.training. Arguments: - in_channels (:obj:`int`): Number of channels in the input tensor. - out_channels (:obj:`int`): Number of channels in the output tensor. @@ -654,6 +657,7 @@ def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> No self.register_buffer("weight_eps", torch.empty(out_channels, in_channels)) self.register_buffer("bias_eps", torch.empty(out_channels)) self.sigma0 = sigma0 + self.enable_noise = False self.reset_parameters() self.reset_noise() @@ -703,7 +707,8 @@ def forward(self, x: torch.Tensor): Returns: - output (:obj:`torch.Tensor`): The output tensor with noise. """ - if self.training: + # Determine whether to add noise: + if self.enable_noise: return F.linear( x, self.weight_mu + self.weight_sigma * self.weight_eps, diff --git a/ding/torch_utils/network/tests/test_nn_module.py b/ding/torch_utils/network/tests/test_nn_module.py index 8fdc7845ee..ac44e15cba 100644 --- a/ding/torch_utils/network/tests/test_nn_module.py +++ b/ding/torch_utils/network/tests/test_nn_module.py @@ -5,7 +5,7 @@ from ding.torch_utils import build_activation from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \ ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \ - normed_linear, normed_conv2d + normed_linear, normed_conv2d, NoiseLinearLayer batch_size = 2 in_channels = 2 @@ -238,3 +238,27 @@ def test_flatten(self): model3 = NaiveFlatten(1, 3) output3 = model2(inputs) assert output1.shape == (4, 3 * 8 * 8) + + def test_noise_linear_layer(self): + input = torch.rand(batch_size, in_channels).requires_grad_(True) + layer = NoiseLinearLayer(in_channels, out_channels, sigma0=0.5) + # No noise by default + output = self.run_model(input, layer) + assert output.shape == (batch_size, out_channels) + # Enable noise + layer.enable_noise = True + layer.reset_noise() + output_noise = self.run_model(input, layer) + assert output_noise.shape == (batch_size, out_channels) + # Check that outputs are different after resetting noise + with torch.no_grad(): + layer.reset_noise() + out1 = layer(input) + layer.reset_noise() + out2 = layer(input) + # The outputs should be different (very likely) + assert not torch.allclose(out1, out2) + # Check reset_parameters + layer.reset_parameters() + assert layer.weight_mu.shape == (out_channels, in_channels) + assert layer.bias_mu.shape == (out_channels, ) diff --git a/dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py b/dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py new file mode 100644 index 0000000000..9669bd1392 --- /dev/null +++ b/dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py @@ -0,0 +1,60 @@ +from easydict import EasyDict + +demon_attack_dqn_config = dict( + exp_name='DemonAttack_dqn_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=1e6, + env_id='DemonAttackNoFrameskip-v4', + frame_stack=4, + ), + policy=dict( + cuda=True, + priority=False, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + noise=True, + ), + nstep=3, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + ), + noisy_net=True, + collect=dict(n_sample=96), + eval=dict(evaluator=dict(eval_freq=4000, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=250000, + ), + replay_buffer=dict(replay_buffer_size=100000, ), + ), + ), +) +demon_attack_dqn_config = EasyDict(demon_attack_dqn_config) +main_config = demon_attack_dqn_config +demon_attack_dqn_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='dqn'), +) +demon_attack_dqn_create_config = EasyDict(demon_attack_dqn_create_config) +create_config = demon_attack_dqn_create_config + +if __name__ == '__main__': + # or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6))