Skip to content

Commit 09f9d27

Browse files
committed
Update
[ghstack-poisoned]
1 parent c82098d commit 09f9d27

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

test/test_collector.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3847,10 +3847,11 @@ def test_weight_update(self, weight_updater):
38473847
)
38483848
policy = policy_factory()
38493849
policy_weights = TensorDict.from_module(policy)
3850+
kwargs = {}
38503851
if weight_updater == "scheme_shared":
3851-
kwargs = {"weight_schemes": {"policy": SharedMemWeightSyncScheme()}}
3852-
elif weight_updater == "scheme_pip":
3853-
kwargs = {"weight_schemes": {"policy": SharedMemWeightSyncScheme()}}
3852+
kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
3853+
elif weight_updater == "scheme_pipe":
3854+
kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
38543855
elif weight_updater == "weight_updater":
38553856
kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
38563857
else:

0 commit comments

Comments
 (0)