Commit 48211cb (2 parents: 1dbb367 + e0ea53b)

Update

[ghstack-poisoned]

File tree: 4 files changed, +91 −197 lines

examples/collectors/weight_sync_collectors.py

Lines changed: 34 additions & 34 deletions
@@ -17,7 +17,7 @@
 import torch.nn as nn
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule
-from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
+from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
 from torchrl.envs import GymEnv
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
@@ -27,25 +27,24 @@
 
 def example_single_collector_multiprocess():
     """Example 1: Single collector with multiprocess scheme."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 1: Single Collector with Multiprocess Scheme")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
         ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-
+
     # Create weight sync scheme
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
-
+
     print("Creating collector with multiprocess weight sync...")
     collector = SyncDataCollector(
         create_env_fn=lambda: GymEnv("CartPole-v1"),
@@ -54,46 +53,45 @@ def example_single_collector_multiprocess():
         total_frames=200,
         weight_sync_schemes={"policy": scheme},
     )
-
+
     # Collect data and update weights periodically
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-
+
         # Update policy weights every 2 iterations
         if i % 2 == 0:
             new_weights = policy.state_dict()
            collector.update_policy_weights_(new_weights)
             print(" → Updated policy weights")
-
+
         if i >= 2:  # Just run a few iterations for demo
             break
-
+
     collector.shutdown()
     print("✓ Single collector example completed!\n")
 
 
 def example_multi_collector_shared_memory():
     """Example 2: Multiple collectors with shared memory."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 2: Multiple Collectors with Shared Memory")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create environment and policy
     env = GymEnv("CartPole-v1")
     policy = TensorDictModule(
         nn.Linear(
-            env.observation_spec["observation"].shape[-1],
-            env.action_spec.shape[-1]
+            env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1]
        ),
         in_keys=["observation"],
         out_keys=["action"],
     )
     env.close()
-
+
     # Shared memory is more efficient for frequent updates
     scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
-
+
     print("Creating multi-collector with shared memory...")
     collector = MultiSyncDataCollector(
         create_env_fn=[
@@ -106,49 +104,51 @@ def example_multi_collector_shared_memory():
         total_frames=400,
         weight_sync_schemes={"policy": scheme},
     )
-
+
     # Workers automatically see weight updates via shared memory
     print("Collecting data...")
     for i, data in enumerate(collector):
         print(f"Iteration {i}: Collected {data.numel()} transitions")
-
+
         # Update weights frequently (shared memory makes this very fast)
         collector.update_policy_weights_(TensorDict.from_module(policy))
         print(" → Updated policy weights via shared memory")
-
+
         if i >= 1:  # Just run a couple iterations for demo
             break
-
+
     collector.shutdown()
     print("✓ Multi-collector with shared memory example completed!\n")
 
 
 def main():
     """Run all examples."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Weight Synchronization Schemes - Collector Integration Examples")
-    print("="*70)
-
+    print("=" * 70)
+
     # Set multiprocessing start method
     import torch.multiprocessing as mp
+
     try:
-        mp.set_start_method('spawn')
+        mp.set_start_method("spawn")
     except RuntimeError:
         pass  # Already set
-
+
     # Run examples
     example_single_collector_multiprocess()
     example_multi_collector_shared_memory()
-
-    print("\n" + "="*70)
+
+    print("\n" + "=" * 70)
     print("All examples completed successfully!")
-    print("="*70)
+    print("=" * 70)
     print("\nKey takeaways:")
     print(" • MultiProcessWeightSyncScheme: Good for general multiprocess scenarios")
-    print(" • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers")
-    print("="*70 + "\n")
+    print(
+        " • SharedMemWeightSyncScheme: Fast zero-copy updates for same-machine workers"
+    )
+    print("=" * 70 + "\n")
 
 
 if __name__ == "__main__":
     main()
-
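Distilled, the integration pattern this file demonstrates is: construct a scheme, pass it to the collector under weight_sync_schemes, and call update_policy_weights_() between batches. The following minimal sketch restates that pattern outside the diff; it is not part of the commit, and the frames_per_batch value is an assumption, since the hunks above elide that argument.

# Minimal sketch of the collector-side pattern shown in the diff above.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme

# CartPole-v1 has 4 observation features and 2 discrete actions.
policy = TensorDictModule(
    nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
)
collector = SyncDataCollector(
    create_env_fn=lambda: GymEnv("CartPole-v1"),
    policy=policy,
    frames_per_batch=50,  # assumed value; this argument is not shown in the hunks
    total_frames=200,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme(strategy="state_dict")},
)
for data in collector:
    # ...an optimizer step on `policy` would go here...
    collector.update_policy_weights_(policy.state_dict())  # push weights to workers
collector.shutdown()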
examples/collectors/weight_sync_standalone.py

Lines changed: 54 additions & 48 deletions
@@ -16,8 +16,8 @@
 
 import torch
 import torch.nn as nn
-from torch import multiprocessing as mp
 from tensordict import TensorDict
+from torch import multiprocessing as mp
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
     SharedMemWeightSyncScheme,
@@ -27,21 +27,21 @@
 def worker_process_mp(child_pipe, model_state):
     """Worker process that receives weights via multiprocessing pipe."""
     print("Worker: Starting...")
-
+
     # Create a policy on the worker side
     policy = nn.Linear(4, 2)
     with torch.no_grad():
         policy.weight.fill_(0.0)
         policy.bias.fill_(0.0)
-
+
     # Create receiver and register the policy
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
     receiver = scheme.create_receiver()
     receiver.register_model(policy)
     receiver.register_worker_transport(child_pipe)
-
+
     print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}")
-
+
     # Receive and apply weights
     result = receiver._transport.receive_weights(timeout=5.0)
     if result is not None:
@@ -50,19 +50,19 @@ def worker_process_mp(child_pipe, model_state):
         print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}")
     else:
         print("Worker: No weights received")
-
+
     # Store final state for verification
-    model_state['weight_sum'] = policy.weight.sum().item()
-    model_state['bias_sum'] = policy.bias.sum().item()
+    model_state["weight_sum"] = policy.weight.sum().item()
+    model_state["bias_sum"] = policy.bias.sum().item()
 
 
 def worker_process_shared_mem(child_pipe, model_state):
     """Worker process that receives shared memory buffer reference."""
     print("SharedMem Worker: Starting...")
-
+
     # Create a policy on the worker side
     policy = nn.Linear(4, 2)
-
+
     # Wait for shared memory buffer registration
     if child_pipe.poll(timeout=10.0):
         data, msg = child_pipe.recv()
@@ -73,129 +73,135 @@ def worker_process_shared_mem(child_pipe, model_state):
             shared_weights.to_module(policy)
             # Send acknowledgment
             child_pipe.send((None, "registered"))
-
+
     # Small delay to ensure main process updates shared memory
     import time
+
     time.sleep(0.5)
-
+
     print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}")
-
+
     # Store final state for verification
-    model_state['weight_sum'] = policy.weight.sum().item()
-    model_state['bias_sum'] = policy.bias.sum().item()
+    model_state["weight_sum"] = policy.weight.sum().item()
+    model_state["bias_sum"] = policy.bias.sum().item()
 
 
 def example_multiprocess_sync():
     """Example 1: Multiprocess weight synchronization with state_dict."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 1: Multiprocess Weight Synchronization")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create a simple policy on main process
     policy = nn.Linear(4, 2)
     with torch.no_grad():
         policy.weight.fill_(1.0)
         policy.bias.fill_(0.5)
-
+
     print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}")
-
+
     # Create scheme and sender
     scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
     sender = scheme.create_sender()
-
+
     # Create pipe for communication
     parent_pipe, child_pipe = mp.Pipe()
     sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
-
+
     # Start worker process
     manager = mp.Manager()
     model_state = manager.dict()
     process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state))
     process.start()
-
+
     # Send weights to worker
     weights = policy.state_dict()
     print("Main: Sending weights to worker...")
     sender.update_weights(weights)
-
+
     # Wait for worker to complete
     process.join(timeout=10.0)
-
+
     if process.is_alive():
         print("Warning: Worker process did not terminate in time")
         process.terminate()
     else:
-        print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}")
+        print(
+            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
+        )
         print(f"✓ Weight synchronization successful!")
 
 
 def example_shared_memory_sync():
     """Example 2: Shared memory weight synchronization."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Example 2: Shared Memory Weight Synchronization")
-    print("="*70)
-
+    print("=" * 70)
+
     # Create a simple policy
     policy = nn.Linear(4, 2)
-
+
     # Create shared memory scheme with auto-registration
     scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
     sender = scheme.create_sender()
-
+
     # Create pipe for lazy registration
     parent_pipe, child_pipe = mp.Pipe()
     sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
-
+
     # Start worker process
     manager = mp.Manager()
     model_state = manager.dict()
-    process = mp.Process(target=worker_process_shared_mem, args=(child_pipe, model_state))
+    process = mp.Process(
+        target=worker_process_shared_mem, args=(child_pipe, model_state)
    )
     process.start()
-
+
     # Send weights (automatically creates shared buffer on first send)
     weights_td = TensorDict.from_module(policy)
     with torch.no_grad():
         weights_td["weight"].fill_(2.0)
         weights_td["bias"].fill_(1.0)
-
+
     print(f"Main: Sending weights via shared memory...")
     sender.update_weights(weights_td)
-
+
     # Workers automatically see updates via shared memory!
     print("Main: Weights are now in shared memory, workers can access them")
-
+
     # Wait for worker to complete
     process.join(timeout=10.0)
-
+
     if process.is_alive():
         print("Warning: Worker process did not terminate in time")
         process.terminate()
     else:
-        print(f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}")
+        print(
+            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
+        )
         print(f"✓ Shared memory synchronization successful!")
 
 
 def main():
     """Run all examples."""
-    print("\n" + "="*70)
+    print("\n" + "=" * 70)
     print("Weight Synchronization Schemes - Standalone Usage Examples")
-    print("="*70)
-
+    print("=" * 70)
+
     # Set multiprocessing start method
     try:
-        mp.set_start_method('spawn')
+        mp.set_start_method("spawn")
     except RuntimeError:
         pass  # Already set
-
+
     # Run examples
     example_multiprocess_sync()
     example_shared_memory_sync()
-
-    print("\n" + "="*70)
+
+    print("\n" + "=" * 70)
     print("All examples completed successfully!")
-    print("="*70 + "\n")
+    print("=" * 70 + "\n")
 
 
 if __name__ == "__main__":
     main()
-
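For quick reference, the sender/receiver handshake exercised by example_multiprocess_sync above condenses to the call order below. This is a sketch, not part of the commit: prints and verification are stripped, only calls visible in the diff are used (including the private receiver._transport.receive_weights), and applying the received state_dict to the model is elided because the hunks above do not show that step.

# Condensed sender/receiver handshake, assuming the torchrl.weight_update APIs above.
import torch.nn as nn
from torch import multiprocessing as mp
from torchrl.weight_update import MultiProcessWeightSyncScheme


def worker(child_pipe):
    # Worker side: bind a receiver to the local model and the pipe transport.
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    policy = nn.Linear(4, 2)
    receiver = scheme.create_receiver()
    receiver.register_model(policy)
    receiver.register_worker_transport(child_pipe)
    result = receiver._transport.receive_weights(timeout=5.0)
    # Applying `result` to `policy` is elided here, as in the hunks above.


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    # Main-process side: one sender with the worker's pipe registered.
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)
    process = mp.Process(target=worker, args=(child_pipe,))
    process.start()
    sender.update_weights(nn.Linear(4, 2).state_dict())  # send a state_dict
    process.join(timeout=10.0)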