7979 RandomPolicy ,
8080)
8181from torchrl .modules import Actor , OrnsteinUhlenbeckProcessModule , SafeModule
82+ from torchrl .weight_update import SharedMemWeightSyncScheme
8283
8384if os .getenv ("PYTORCH_TEST_FBCODE" ):
8485 IS_FB = True
@@ -3835,14 +3836,25 @@ def all_worker_ids(self) -> list[int] | list[torch.device]:
38353836
38363837 @pytest .mark .skipif (not _has_cuda , reason = "requires cuda another device than CPU." )
38373838 @pytest .mark .skipif (not _has_gym , reason = "requires gym" )
3838- def test_weight_update (self ):
3839+ @pytest .mark .parametrize (
3840+ "weight_updater" , ["scheme_shared" , "scheme_pipe" , "weight_updater" ]
3841+ )
3842+ def test_weight_update (self , weight_updater ):
38393843 device = "cuda:0"
38403844 env_maker = lambda : GymEnv (PENDULUM_VERSIONED (), device = "cpu" )
38413845 policy_factory = lambda : TensorDictModule (
3842- nn .Linear (3 , 1 ), in_keys = ["observation" ], out_keys = ["action" ]
3843- ). to ( device )
3846+ nn .Linear (3 , 1 , device = device ), in_keys = ["observation" ], out_keys = ["action" ]
3847+ )
38443848 policy = policy_factory ()
38453849 policy_weights = TensorDict .from_module (policy )
3850+ if weight_updater == "scheme_shared" :
3851+ kwargs = {"weight_schemes" : {"policy" : SharedMemWeightSyncScheme ()}}
3852+ elif weight_updater == "scheme_pipe" :
3853+ kwargs = {"weight_schemes" : {"policy" : SharedMemWeightSyncScheme ()}}
3854+ elif weight_updater == "weight_updater" :
3855+ kwargs = {"weight_updater" : self .MPSWeightUpdaterBase (policy_weights , 2 )}
3856+ else :
3857+ raise NotImplementedError
38463858
38473859 collector = MultiSyncDataCollector (
38483860 create_env_fn = [env_maker , env_maker ],
@@ -3854,7 +3866,7 @@ def test_weight_update(self):
38543866 reset_at_each_iter = False ,
38553867 device = device ,
38563868 storing_device = "cpu" ,
3857- weight_updater = self . MPSWeightUpdaterBase ( policy_weights , 2 ) ,
3869+ ** kwargs ,
38583870 )
38593871
38603872 collector .update_policy_weights_ ()
0 commit comments