import torch
import torch.nn as nn
from tensordict import TensorDict
from torch import multiprocessing as mp
from torchrl.weight_update import (
    MultiProcessWeightSyncScheme,
    SharedMemWeightSyncScheme,
)


def worker_process_mp(child_pipe, model_state):
    """Worker process that receives weights via multiprocessing pipe."""
    print("Worker: Starting...")

    # Create a policy on the worker side
    policy = nn.Linear(4, 2)
    with torch.no_grad():
        policy.weight.fill_(0.0)
        policy.bias.fill_(0.0)

    # Create receiver and register the policy
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    receiver = scheme.create_receiver()
    receiver.register_model(policy)
    receiver.register_worker_transport(child_pipe)

    print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}")

    # Receive and apply weights
    result = receiver._transport.receive_weights(timeout=5.0)
    if result is not None:
        # ... (lines applying the received weights are not shown in this excerpt)
        print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}")
    else:
        print("Worker: No weights received")

    # Store final state for verification
    model_state["weight_sum"] = policy.weight.sum().item()
    model_state["bias_sum"] = policy.bias.sum().item()



def worker_process_shared_mem(child_pipe, model_state):
    """Worker process that receives shared memory buffer reference."""
    print("SharedMem Worker: Starting...")

    # Create a policy on the worker side
    policy = nn.Linear(4, 2)

    # Wait for shared memory buffer registration
    if child_pipe.poll(timeout=10.0):
        data, msg = child_pipe.recv()
        # ... (handling of the registration message is not shown in this excerpt;
        # it yields the shared TensorDict used below as `shared_weights`)
        shared_weights.to_module(policy)
        # Send acknowledgment
        child_pipe.send((None, "registered"))

    # Small delay to ensure main process updates shared memory
    import time

    time.sleep(0.5)

    print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}")

    # Store final state for verification
    model_state["weight_sum"] = policy.weight.sum().item()
    model_state["bias_sum"] = policy.bias.sum().item()



def example_multiprocess_sync():
    """Example 1: Multiprocess weight synchronization with state_dict."""
    print("\n" + "=" * 70)
    print("Example 1: Multiprocess Weight Synchronization")
    print("=" * 70)

    # Create a simple policy on main process
    policy = nn.Linear(4, 2)
    with torch.no_grad():
        policy.weight.fill_(1.0)
        policy.bias.fill_(0.5)

    print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}")

    # Create scheme and sender
    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    # Create pipe for communication
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Start worker process
    manager = mp.Manager()
    model_state = manager.dict()
    process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state))
    process.start()

    # Send weights to worker
    weights = policy.state_dict()
    print("Main: Sending weights to worker...")
    sender.update_weights(weights)

    # Wait for worker to complete
    process.join(timeout=10.0)

    if process.is_alive():
        print("Warning: Worker process did not terminate in time")
        process.terminate()
    else:
        print(
            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
        )
        print("✓ Weight synchronization successful!")

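

# A minimal sketch, not part of the original example, of how the same calls
# shown above could drive several workers at once: one Pipe per worker, each
# registered under its own worker_idx, reusing worker_process_mp as the target.
# That a single update_weights() call reaches every registered worker is an
# assumption here; the example above only exercises worker 0.
def example_multiprocess_sync_many_workers(num_workers=2):
    """Hypothetical variant of Example 1 with several workers."""
    policy = nn.Linear(4, 2)
    with torch.no_grad():
        policy.weight.fill_(1.0)
        policy.bias.fill_(0.5)

    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
    sender = scheme.create_sender()

    manager = mp.Manager()
    states = [manager.dict() for _ in range(num_workers)]
    processes = []
    for idx in range(num_workers):
        parent_pipe, child_pipe = mp.Pipe()
        sender.register_worker(worker_idx=idx, pipe_or_context=parent_pipe)
        proc = mp.Process(target=worker_process_mp, args=(child_pipe, states[idx]))
        proc.start()
        processes.append(proc)

    # One call on the sender side; each worker applies the weights it receives
    sender.update_weights(policy.state_dict())

    for proc in processes:
        proc.join(timeout=10.0)
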


def example_shared_memory_sync():
    """Example 2: Shared memory weight synchronization."""
    print("\n" + "=" * 70)
    print("Example 2: Shared Memory Weight Synchronization")
    print("=" * 70)

    # Create a simple policy
    policy = nn.Linear(4, 2)

    # Create shared memory scheme with auto-registration
    scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
    sender = scheme.create_sender()

    # Create pipe for lazy registration
    parent_pipe, child_pipe = mp.Pipe()
    sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe)

    # Start worker process
    manager = mp.Manager()
    model_state = manager.dict()
    process = mp.Process(
        target=worker_process_shared_mem, args=(child_pipe, model_state)
    )
    process.start()

    # Send weights (automatically creates shared buffer on first send)
    weights_td = TensorDict.from_module(policy)
    with torch.no_grad():
        weights_td["weight"].fill_(2.0)
        weights_td["bias"].fill_(1.0)

    print("Main: Sending weights via shared memory...")
    sender.update_weights(weights_td)

    # Workers automatically see updates via shared memory!
    print("Main: Weights are now in shared memory, workers can access them")

    # Wait for worker to complete
    process.join(timeout=10.0)

    if process.is_alive():
        print("Warning: Worker process did not terminate in time")
        process.terminate()
    else:
        print(
            f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
        )
        print("✓ Shared memory synchronization successful!")

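

# A hypothetical usage sketch (not in the original example): once the shared
# buffer has been registered as in Example 2, re-sending after every optimizer
# step is assumed to keep the shared memory, and hence every worker's view of
# the policy, in sync. `policy` and `sender` are taken to be the objects built
# in example_shared_memory_sync above.
def _shared_mem_training_loop_sketch(policy, sender, num_steps=3):
    """Push freshly optimized weights through the shared-memory sender."""
    optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)
    for _ in range(num_steps):
        loss = policy(torch.randn(8, 4)).pow(2).mean()  # dummy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Publish the updated parameters to the shared buffer
        sender.update_weights(TensorDict.from_module(policy))
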


def main():
    """Run all examples."""
    print("\n" + "=" * 70)
    print("Weight Synchronization Schemes - Standalone Usage Examples")
    print("=" * 70)

    # Set multiprocessing start method
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        pass  # Already set

    # Run examples
    example_multiprocess_sync()
    example_shared_memory_sync()

    print("\n" + "=" * 70)
    print("All examples completed successfully!")
    print("=" * 70 + "\n")


if __name__ == "__main__":
    main()