Skip to content

Commit 6326468

Browse files
committed
Update
[ghstack-poisoned]
1 parent ff50b25 commit 6326468

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

test/test_collector.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
RandomPolicy,
8080
)
8181
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
82+
from torchrl.weight_update import SharedMemWeightSyncScheme
8283

8384
if os.getenv("PYTORCH_TEST_FBCODE"):
8485
IS_FB = True
@@ -3835,14 +3836,25 @@ def all_worker_ids(self) -> list[int] | list[torch.device]:
38353836

38363837
@pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.")
38373838
@pytest.mark.skipif(not _has_gym, reason="requires gym")
3838-
def test_weight_update(self):
3839+
@pytest.mark.parametrize(
3840+
"weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"]
3841+
)
3842+
def test_weight_update(self, weight_updater):
38393843
device = "cuda:0"
38403844
env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu")
38413845
policy_factory = lambda: TensorDictModule(
3842-
nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
3843-
).to(device)
3846+
nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"]
3847+
)
38443848
policy = policy_factory()
38453849
policy_weights = TensorDict.from_module(policy)
3850+
if weight_updater == "scheme_shared":
3851+
kwargs = {"weight_schemes": {"policy": SharedMemWeightSyncScheme()}}
3852+
elif weight_updater == "scheme_pipe":
3853+
kwargs = {"weight_schemes": {"policy": SharedMemWeightSyncScheme()}}
3854+
elif weight_updater == "weight_updater":
3855+
kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
3856+
else:
3857+
raise NotImplementedError
38463858

38473859
collector = MultiSyncDataCollector(
38483860
create_env_fn=[env_maker, env_maker],
@@ -3854,7 +3866,7 @@ def test_weight_update(self):
38543866
reset_at_each_iter=False,
38553867
device=device,
38563868
storing_device="cpu",
3857-
weight_updater=self.MPSWeightUpdaterBase(policy_weights, 2),
3869+
**kwargs,
38583870
)
38593871

38603872
collector.update_policy_weights_()

0 commit comments

Comments
 (0)