Skip to content

Commit 888095f

Browse files
committed
[BugFix] Fix collector devices
ghstack-source-id: 21eb8cc Pull-Request: #3223
1 parent 3d1748f commit 888095f

File tree

2 files changed

+28
-33
lines changed

2 files changed

+28
-33
lines changed

test/test_collector.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
RandomPolicy,
8080
)
8181
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
82+
from torchrl.weight_update import SharedMemWeightSyncScheme
8283

8384
if os.getenv("PYTORCH_TEST_FBCODE"):
8485
IS_FB = True
@@ -3835,14 +3836,26 @@ def all_worker_ids(self) -> list[int] | list[torch.device]:
38353836

38363837
@pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.")
38373838
@pytest.mark.skipif(not _has_gym, reason="requires gym")
3838-
def test_weight_update(self):
3839+
@pytest.mark.parametrize(
3840+
"weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"]
3841+
)
3842+
def test_weight_update(self, weight_updater):
38393843
device = "cuda:0"
38403844
env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu")
38413845
policy_factory = lambda: TensorDictModule(
3842-
nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
3843-
).to(device)
3846+
nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"]
3847+
)
38443848
policy = policy_factory()
38453849
policy_weights = TensorDict.from_module(policy)
3850+
kwargs = {}
3851+
if weight_updater == "scheme_shared":
3852+
kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
3853+
elif weight_updater == "scheme_pipe":
3854+
kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
3855+
elif weight_updater == "weight_updater":
3856+
kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
3857+
else:
3858+
raise NotImplementedError
38463859

38473860
collector = MultiSyncDataCollector(
38483861
create_env_fn=[env_maker, env_maker],
@@ -3854,7 +3867,7 @@ def test_weight_update(self):
38543867
reset_at_each_iter=False,
38553868
device=device,
38563869
storing_device="cpu",
3857-
weight_updater=self.MPSWeightUpdaterBase(policy_weights, 2),
3870+
**kwargs,
38583871
)
38593872

38603873
collector.update_policy_weights_()

torchrl/weight_update/weight_sync_schemes.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -687,36 +687,14 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
687687

688688
# Auto-detect format from weights type
689689
if isinstance(weights, dict):
690-
# Apply state_dict format
691-
if isinstance(destination, nn.Module):
692-
destination.load_state_dict(weights)
693-
elif isinstance(destination, dict):
694-
destination = TensorDict(destination)
695-
weights = TensorDict(weights)
696-
destination.data.update_(weights.data)
697-
elif isinstance(destination, TensorDictBase):
698-
weights_td = TensorDict(weights)
699-
if (dest_keys := sorted(destination.keys(True, True))) != sorted(
700-
weights.keys(True, True)
701-
):
702-
weights_td = weights_td.unflatten_keys(".")
703-
weights_keys = sorted(weights_td.keys(True, True))
704-
if dest_keys != weights_keys:
705-
raise ValueError(
706-
f"The keys of the weights and destination do not match: {dest_keys} != {weights_keys}"
707-
)
708-
destination.data.update_(weights_td.data)
709-
else:
710-
raise ValueError(
711-
f"Unsupported destination type for state_dict: {type(destination)}"
712-
)
713-
elif isinstance(weights, TensorDictBase):
690+
weights = TensorDict(weights).unflatten_keys(".")
691+
692+
if isinstance(weights, TensorDictBase):
714693
# Apply TensorDict format
715694
if isinstance(destination, nn.Module):
716-
weights.to_module(destination)
717-
elif isinstance(destination, TensorDictBase):
718-
destination.data.update_(weights.data)
719-
elif isinstance(destination, dict):
695+
destination = TensorDict.from_module(destination)
696+
697+
if isinstance(destination, dict):
720698
destination_td = TensorDict(destination)
721699
if (dest_keys := sorted(destination_td.keys(True, True))) != sorted(
722700
weights.keys(True, True)
@@ -727,11 +705,15 @@ def apply_weights(self, destination: Any, weights: Any) -> None:
727705
raise ValueError(
728706
f"The keys of the weights and destination do not match: {dest_keys} != {weights_keys}"
729707
)
730-
destination_td.data.update_(weights.data)
708+
destination = destination_td
709+
710+
if isinstance(destination, TensorDictBase):
711+
destination.data.update_(weights.data)
731712
else:
732713
raise ValueError(
733714
f"Unsupported destination type for TensorDict: {type(destination)}"
734715
)
716+
735717
else:
736718
raise ValueError(
737719
f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase."

0 commit comments

Comments
 (0)