Skip to content

Commit ce9f29e

Browse files
committed
Update
[ghstack-poisoned]
2 parents 48211cb + ae5f22b commit ce9f29e

File tree

1 file changed

+0
-108
lines changed

1 file changed

+0
-108
lines changed

test/test_weightsync.py

Lines changed: 0 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -272,114 +272,6 @@ def test_no_weight_sync_scheme(self):
272 272
transport.send_weights("policy", weights)
273 273

274 274

275-
class TestMultiModelUpdates:
276-
def test_multi_model_state_dict_updates(self):
277-
env = GymEnv("CartPole-v1")
278-
279-
policy = TensorDictModule(
280-
nn.Linear(
281-
env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
282-
),
283-
in_keys=["observation"],
284-
out_keys=["action"],
285-
)
286-
287-
value = TensorDictModule(
288-
nn.Linear(env.observation_spec["observation"].shape[-1], 1),
289-
in_keys=["observation"],
290-
out_keys=["value"],
291-
)
292-
293-
weight_sync_schemes = {
294-
"policy": MultiProcessWeightSyncScheme(strategy="state_dict"),
295-
"value": MultiProcessWeightSyncScheme(strategy="state_dict"),
296-
}
297-
298-
collector = SyncDataCollector(
299-
create_env_fn=lambda: GymEnv("CartPole-v1"),
300-
policy=policy,
301-
frames_per_batch=64,
302-
total_frames=128,
303-
weight_sync_schemes=weight_sync_schemes,
304-
)
305-
306-
policy_weights = policy.state_dict()
307-
value_weights = value.state_dict()
308-
309-
with torch.no_grad():
310-
for key in policy_weights:
311-
policy_weights[key].fill_(1.0)
312-
for key in value_weights:
313-
value_weights[key].fill_(2.0)
314-
315-
collector.update_policy_weights_(
316-
weights_dict={
317-
"policy": policy_weights,
318-
"value": value_weights,
319-
}
320-
)
321-
322-
for data in collector:
323-
assert data.numel() > 0
324-
break
325-
326-
collector.shutdown()
327-
env.close()
328-
329-
def test_multi_model_tensordict_updates(self):
330-
env = GymEnv("CartPole-v1")
331-
332-
policy = TensorDictModule(
333-
nn.Linear(
334-
env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
335-
),
336-
in_keys=["observation"],
337-
out_keys=["action"],
338-
)
339-
340-
value = TensorDictModule(
341-
nn.Linear(env.observation_spec["observation"].shape[-1], 1),
342-
in_keys=["observation"],
343-
out_keys=["value"],
344-
)
345-
346-
weight_sync_schemes = {
347-
"policy": MultiProcessWeightSyncScheme(strategy="tensordict"),
348-
"value": MultiProcessWeightSyncScheme(strategy="tensordict"),
349-
}
350-
351-
collector = SyncDataCollector(
352-
create_env_fn=lambda: GymEnv("CartPole-v1"),
353-
policy=policy,
354-
frames_per_batch=64,
355-
total_frames=128,
356-
weight_sync_schemes=weight_sync_schemes,
357-
)
358-
359-
policy_weights = TensorDict.from_module(policy)
360-
value_weights = TensorDict.from_module(value)
361-
362-
with torch.no_grad():
363-
policy_weights["module"]["weight"].fill_(1.0)
364-
policy_weights["module"]["bias"].fill_(1.0)
365-
value_weights["module"]["weight"].fill_(2.0)
366-
value_weights["module"]["bias"].fill_(2.0)
367-
368-
collector.update_policy_weights_(
369-
weights_dict={
370-
"policy": policy_weights,
371-
"value": value_weights,
372-
}
373-
)
374-
375-
for data in collector:
376-
assert data.numel() > 0
377-
break
378-
379-
collector.shutdown()
380-
env.close()
381-
382-
383 275
class TestHelpers:
384 276
def test_resolve_model_simple(self):
385 277
class Context:

0 commit comments

Comments
 (0)