-
Notifications
You must be signed in to change notification settings - Fork 422
fix(pu): fix noise layer's usage based on the original paper #866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 2 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
5a01fde
fix(pu): fix noise layer's usage
454334c
polish(pu): polish comments
ee07a99
polish(pu): polish noisy_net config
puyuan1996 41c810d
fix(pu): fix reset_noise bug in noisy_net option
puyuan1996 681488c
fix(pu): fix enable_noise bug in rainbow
puyuan1996 688bfd7
style(pu): yapf format
puyuan1996 ad1a2a9
style(pu): yapf format
puyuan1996 3133049
style(pu): flake8 format
puyuan1996 83b5fbe
style(pu): yapf format
puyuan1996 a6be7d4
polish(pu): polish set_noise_mode when self._cfg.noisy_net is False
puyuan1996 44588d0
fature(pu): add unittest for noise_linear_layer
puyuan1996 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| from ding.utils.data import default_collate, default_decollate | ||
|
|
||
| from .base_policy import Policy | ||
| from .common_utils import default_preprocess_learn | ||
| from .common_utils import default_preprocess_learn, set_noise_mode | ||
|
|
||
|
|
||
| @POLICY_REGISTRY.register('dqn') | ||
|
|
@@ -248,6 +248,8 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: | |
| .. note:: | ||
| For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. | ||
| """ | ||
| set_noise_mode(self._learn_model, True) | ||
|
||
|
|
||
| # Data preprocessing operations, such as stack data, cpu to cuda device | ||
| data = default_preprocess_learn( | ||
| data, | ||
|
|
@@ -384,6 +386,12 @@ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]: | |
| data = default_collate(list(data.values())) | ||
| if self._cuda: | ||
| data = to_device(data, self._device) | ||
| # Use the add_noise parameter to decide noise mode. | ||
puyuan1996 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Default to True if the parameter is not provided. | ||
| if self._cfg.collect.get("add_noise", True): | ||
| set_noise_mode(self._collect_model, True) | ||
| else: | ||
| set_noise_mode(self._collect_model, False) | ||
| self._collect_model.eval() | ||
| with torch.no_grad(): | ||
| output = self._collect_model.forward(data, eps=eps) | ||
|
|
@@ -476,6 +484,8 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: | |
| data = default_collate(list(data.values())) | ||
| if self._cuda: | ||
| data = to_device(data, self._device) | ||
| # Ensure that in evaluation mode noise is disabled. | ||
| set_noise_mode(self._eval_model, False) | ||
| self._eval_model.eval() | ||
| with torch.no_grad(): | ||
| output = self._eval_model.forward(data) | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
dizoo/atari/config/serial/demon_attack/demon_attack_dqn_config.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| from easydict import EasyDict | ||
|
|
||
| demon_attack_dqn_config = dict( | ||
| exp_name='DemonAttack_dqn_collect-not-noise_seed0', | ||
| env=dict( | ||
| collector_env_num=8, | ||
| evaluator_env_num=8, | ||
| n_evaluator_episode=8, | ||
| stop_value=1e6, | ||
| env_id='DemonAttackNoFrameskip-v4', | ||
| frame_stack=4, | ||
| ), | ||
| policy=dict( | ||
| cuda=True, | ||
| priority=False, | ||
| model=dict( | ||
| obs_shape=[4, 84, 84], | ||
| action_shape=6, | ||
| encoder_hidden_size_list=[128, 128, 512], | ||
| noise=True, | ||
| ), | ||
| nstep=3, | ||
| discount_factor=0.99, | ||
| learn=dict( | ||
| update_per_collect=10, | ||
| batch_size=32, | ||
| learning_rate=0.0001, | ||
| target_update_freq=500, | ||
| ), | ||
| # collect=dict(n_sample=96, add_noise=True), | ||
| collect=dict(n_sample=96, add_noise=False), | ||
| eval=dict(evaluator=dict(eval_freq=4000, )), | ||
| other=dict( | ||
| eps=dict( | ||
| type='exp', | ||
| start=1., | ||
| end=0.05, | ||
| decay=250000, | ||
| ), | ||
| replay_buffer=dict(replay_buffer_size=100000, ), | ||
| ), | ||
| ), | ||
| ) | ||
| demon_attack_dqn_config = EasyDict(demon_attack_dqn_config) | ||
| main_config = demon_attack_dqn_config | ||
| demon_attack_dqn_create_config = dict( | ||
| env=dict( | ||
| type='atari', | ||
| import_names=['dizoo.atari.envs.atari_env'], | ||
| ), | ||
| env_manager=dict(type='subprocess'), | ||
| policy=dict(type='dqn'), | ||
| ) | ||
| demon_attack_dqn_create_config = EasyDict(demon_attack_dqn_create_config) | ||
| create_config = demon_attack_dqn_create_config | ||
|
|
||
| if __name__ == '__main__': | ||
| # or you can enter `ding -m serial -c demon_attack_dqn_config.py -s 0` | ||
| from ding.entry import serial_pipeline | ||
| serial_pipeline((main_config, create_config), seed=0, max_env_step=int(10e6)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.