@@ -272,114 +272,6 @@ def test_no_weight_sync_scheme(self):
272272 transport .send_weights ("policy" , weights )
273273
274274
275- class TestMultiModelUpdates :
276- def test_multi_model_state_dict_updates (self ):
277- env = GymEnv ("CartPole-v1" )
278-
279- policy = TensorDictModule (
280- nn .Linear (
281- env .observation_spec ["observation" ].shape [- 1 ], env .action_spec .shape [- 1 ]
282- ),
283- in_keys = ["observation" ],
284- out_keys = ["action" ],
285- )
286-
287- value = TensorDictModule (
288- nn .Linear (env .observation_spec ["observation" ].shape [- 1 ], 1 ),
289- in_keys = ["observation" ],
290- out_keys = ["value" ],
291- )
292-
293- weight_sync_schemes = {
294- "policy" : MultiProcessWeightSyncScheme (strategy = "state_dict" ),
295- "value" : MultiProcessWeightSyncScheme (strategy = "state_dict" ),
296- }
297-
298- collector = SyncDataCollector (
299- create_env_fn = lambda : GymEnv ("CartPole-v1" ),
300- policy = policy ,
301- frames_per_batch = 64 ,
302- total_frames = 128 ,
303- weight_sync_schemes = weight_sync_schemes ,
304- )
305-
306- policy_weights = policy .state_dict ()
307- value_weights = value .state_dict ()
308-
309- with torch .no_grad ():
310- for key in policy_weights :
311- policy_weights [key ].fill_ (1.0 )
312- for key in value_weights :
313- value_weights [key ].fill_ (2.0 )
314-
315- collector .update_policy_weights_ (
316- weights_dict = {
317- "policy" : policy_weights ,
318- "value" : value_weights ,
319- }
320- )
321-
322- for data in collector :
323- assert data .numel () > 0
324- break
325-
326- collector .shutdown ()
327- env .close ()
328-
329- def test_multi_model_tensordict_updates (self ):
330- env = GymEnv ("CartPole-v1" )
331-
332- policy = TensorDictModule (
333- nn .Linear (
334- env .observation_spec ["observation" ].shape [- 1 ], env .action_spec .shape [- 1 ]
335- ),
336- in_keys = ["observation" ],
337- out_keys = ["action" ],
338- )
339-
340- value = TensorDictModule (
341- nn .Linear (env .observation_spec ["observation" ].shape [- 1 ], 1 ),
342- in_keys = ["observation" ],
343- out_keys = ["value" ],
344- )
345-
346- weight_sync_schemes = {
347- "policy" : MultiProcessWeightSyncScheme (strategy = "tensordict" ),
348- "value" : MultiProcessWeightSyncScheme (strategy = "tensordict" ),
349- }
350-
351- collector = SyncDataCollector (
352- create_env_fn = lambda : GymEnv ("CartPole-v1" ),
353- policy = policy ,
354- frames_per_batch = 64 ,
355- total_frames = 128 ,
356- weight_sync_schemes = weight_sync_schemes ,
357- )
358-
359- policy_weights = TensorDict .from_module (policy )
360- value_weights = TensorDict .from_module (value )
361-
362- with torch .no_grad ():
363- policy_weights ["module" ]["weight" ].fill_ (1.0 )
364- policy_weights ["module" ]["bias" ].fill_ (1.0 )
365- value_weights ["module" ]["weight" ].fill_ (2.0 )
366- value_weights ["module" ]["bias" ].fill_ (2.0 )
367-
368- collector .update_policy_weights_ (
369- weights_dict = {
370- "policy" : policy_weights ,
371- "value" : value_weights ,
372- }
373- )
374-
375- for data in collector :
376- assert data .numel () > 0
377- break
378-
379- collector .shutdown ()
380- env .close ()
381-
382-
383275class TestHelpers :
384276 def test_resolve_model_simple (self ):
385277 class Context :
0 commit comments