File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -3847,10 +3847,11 @@ def test_weight_update(self, weight_updater):
38473847 )
38483848 policy = policy_factory ()
38493849 policy_weights = TensorDict .from_module (policy )
3850+ kwargs = {}
38503851 if weight_updater == "scheme_shared" :
3851- kwargs = {"weight_schemes " : {"policy" : SharedMemWeightSyncScheme ()}}
3852- elif weight_updater == "scheme_pip " :
3853- kwargs = {"weight_schemes " : {"policy" : SharedMemWeightSyncScheme ()}}
3852+ kwargs = {"weight_sync_schemes " : {"policy" : SharedMemWeightSyncScheme ()}}
3853+ elif weight_updater == "scheme_pipe " :
3854+ kwargs = {"weight_sync_schemes " : {"policy" : SharedMemWeightSyncScheme ()}}
38543855 elif weight_updater == "weight_updater" :
38553856 kwargs = {"weight_updater" : self .MPSWeightUpdaterBase (policy_weights , 2 )}
38563857 else :
You can’t perform that action at this time.
0 commit comments