@@ -117,14 +117,19 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
117117
118118 Policy copy decision tree in Collectors.
119119
120- Weight Synchronization using Weight Update Schemes
121- --------------------------------------------------
120+ Weight Synchronization
121+ ----------------------
122122
123- RL pipelines are typically split in two big computational buckets: training, and inference.
124- While the inference pipeline sends data to the training one, the training pipeline needs to occasionally
125- synchronize its weights with the inference one.
126- In the most basic setting (fully synchronized data collection with traditional neural networks), the same weights are
127- used in both instances. From there, anything can happen:
123+ In reinforcement learning, the training pipeline is typically split into two computational phases:
124+ **inference** (data collection) and **training** (policy optimization). While the inference pipeline
125+ sends data to the training one, the training pipeline needs to periodically synchronize its weights
126+ with the inference workers to ensure they collect data using up-to-date policies.
127+
128+ Overview & Motivation
129+ ~~~~~~~~~~~~~~~~~~~~~
130+
131+ In the simplest setting, the same policy weights are used in both training and inference. However,
132+ real-world RL systems often face additional complexity:
128133
129134- In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named
130135 `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights
@@ -140,15 +145,222 @@ used in both instances. From there, anything can happen:
140145 asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach
141146 is to store the weights on some intermediary server and let the workers fetch them when necessary.
142147
143- TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight
144- transfer:
148+ Key Challenges
149+ ^^^^^^^^^^^^^^
150+
151+ Modern RL training often involves multiple models that need independent synchronization:
152+
153+ 1. **Multiple Models Per Collector**: A collector might need to update:
154+
155+ - The main policy network
156+ - A value network in a Ray actor within the replay buffer
157+ - Models embedded in the environment itself
158+ - Separate world models or auxiliary networks
159+
160+ 2. **Different Update Strategies**: Each model may require different synchronization approaches:
161+
162+ - Full state_dict transfer vs. TensorDict-based updates
163+ - Different transport mechanisms (multiprocessing pipes, shared memory, Ray object store, collective communication, RDMA, etc.)
164+ - Varied update frequencies
165+
166+ 3. **Worker-Agnostic Updates**: Some models (like those in shared Ray actors) shouldn't be tied to
167+ specific worker indices, requiring a more flexible update mechanism.
168+
169+ The Solution
170+ ^^^^^^^^^^^^
171+
172+ TorchRL addresses these challenges through a flexible, modular architecture built around four components:
173+
174+ - **WeightSyncScheme**: Defines *what* to synchronize and *how* (user-facing configuration)
175+ - **WeightSender**: Handles distributing weights from the main process to workers (internal)
176+ - **WeightReceiver**: Handles applying weights in worker processes (internal)
177+ - **TransportBackend**: Manages the actual communication layer (internal)
178+
179+ This design allows you to independently configure synchronization for multiple models,
180+ choose appropriate transport mechanisms, and swap strategies without rewriting your training code.
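
The division of labor above can be sketched with plain Python objects. This is only an
illustrative toy under assumed semantics (none of these classes are TorchRL's actual
implementations): a scheme wires one sender and one receiver around a shared transport.

```python
class QueueTransport:
    """Toy transport: an in-process list standing in for a pipe or shared memory."""
    def __init__(self):
        self._box = []

    def send(self, payload):
        self._box.append(payload)

    def receive(self):
        return self._box.pop(0) if self._box else None


class Sender:
    """Holds the source model and pushes weight copies through the transport."""
    def __init__(self, model, transport):
        self.model, self.transport = model, transport

    def push(self):
        # "extract" step: grab a copy of the current weights
        self.transport.send(dict(self.model))


class Receiver:
    """Holds the destination model and applies weights arriving on the transport."""
    def __init__(self, model, transport):
        self.model, self.transport = model, transport

    def poll(self):
        weights = self.transport.receive()
        if weights is not None:
            self.model.update(weights)  # "apply" step


class Scheme:
    """Factory tying the pieces together."""
    def create_pair(self, src, dst):
        transport = QueueTransport()
        return Sender(src, transport), Receiver(dst, transport)


# Plain dicts stand in for the trainer-side and worker-side networks.
trainer_weights = {"w": 1.0}
worker_weights = {"w": 0.0}
sender, receiver = Scheme().create_pair(trainer_weights, worker_weights)

trainer_weights["w"] = 2.0   # a training step updates the weights
sender.push()
receiver.poll()
print(worker_weights["w"])   # 2.0
```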
181+
182+ Architecture & Concepts
183+ ~~~~~~~~~~~~~~~~~~~~~~~
184+
185+ Component Roles
186+ ^^^^^^^^^^^^^^^
145187
146- - A `Sender ` class that somehow gets the weights (or a reference to them) and initializes the transfer;
147- - A `Receiver ` class that casts the weights to the destination module (policy or other utility module);
148- - A `Transport ` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else).
149- - A Scheme that defines what sender, receiver and transport have to be used and how to initialize them.
188+ The weight synchronization system separates concerns into four distinct layers:
150189
151- Each of these classes is detailed below.
190+ 1. **WeightSyncScheme** (User-Facing)
191+
192+ This is your main configuration interface. You create scheme objects that define:
193+
194+ - The synchronization strategy (``"state_dict"`` or ``"tensordict"``)
195+ - The transport mechanism (multiprocessing pipes, shared memory, Ray, RPC, etc.)
196+ - Additional options like auto-registration and timeout behavior
197+
198+ When working with collectors, you pass a dictionary mapping model IDs to schemes.
199+
200+ 2. **WeightSender** (Internal)
201+
202+ Created by the scheme in the main training process. The sender:
203+
204+ - Holds a reference to the source model
205+ - Manages transport connections to all workers
206+ - Extracts weights using the configured strategy
207+ - Broadcasts weight updates across all transports
208+
209+ 3. **WeightReceiver** (Internal)
210+
211+ Created by the scheme in each worker process. The receiver:
212+
213+ - Holds a reference to the destination model
214+ - Polls its transport for weight updates
215+ - Applies received weights using the configured strategy
216+ - Handles model registration and initialization
217+
218+ 4. **TransportBackend** (Internal)
219+
220+ Implements the actual communication mechanism:
221+
222+ - ``MPTransport``: Uses multiprocessing pipes
223+ - ``SharedMemTransport``: Uses shared memory buffers (zero-copy)
224+ - ``RayTransport``: Uses Ray's object store
225+ - ``RPCTransport``: Uses PyTorch RPC
226+ - ``DistributedTransport``: Uses collective communication (NCCL, Gloo, MPI)
227+
228+ Initialization Phase
229+ ^^^^^^^^^^^^^^^^^^^^
230+
231+ When you create a collector with weight sync schemes, the following initialization occurs:
232+
233+ .. aafig::
234+ :aspect: 60
235+ :scale: 130
236+ :proportional:
237+
238+ INITIALIZATION PHASE
239+ ====================
240+
241+ WeightSyncScheme
242+ +------------------+
243+ | |
244+ | Configuration: |
245+ | - strategy |
246+ | - transport_type |
247+ | |
248+ +--------+---------+
249+ |
250+ +------------+-------------+
251+ | |
252+ creates creates
253+ | |
254+ v v
255+ Main Process Worker Process
256+ +--------------+ +---------------+
257+ | WeightSender | | WeightReceiver|
258+ | | | |
259+ | - strategy | | - strategy |
260+ | - transports | | - transport |
261+ | - model_ref | | - model_ref |
262+ | | | |
263+ | Registers: | | Registers: |
264+ | - model | | - model |
265+ | - workers | | - transport |
266+ +--------------+ +---------------+
267+ | |
268+ | Transport Layer |
269+ | +----------------+ |
270+ +-->+ MPTransport |<------+
271+ | | (pipes) | |
272+ | +----------------+ |
273+ | +----------------+ |
274+ +-->+ SharedMemTrans |<------+
275+ | | (shared mem) | |
276+ | +----------------+ |
277+ | +----------------+ |
278+ +-->+ RayTransport |<------+
279+ | (Ray store) |
280+ +----------------+
281+
282+ The scheme creates a sender in the main process and a receiver in each worker, then establishes
283+ transport connections between them.
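
The fan-out described above can be traced in a rough in-process sketch. It uses both ends of
``multiprocessing.Pipe`` in a single process rather than real worker processes, and the class
names are placeholders, not TorchRL's internals: one sender in the main process holds one
transport per worker, and each worker holds one receiver.

```python
import multiprocessing as mp


class PipeTransport:
    """Thin wrapper around one end of a multiprocessing pipe."""
    def __init__(self, conn):
        self.conn = conn

    def send(self, payload):
        self.conn.send(payload)

    def receive(self):
        # Non-blocking-ish poll with a short timeout, then read if data arrived.
        return self.conn.recv() if self.conn.poll(1.0) else None


class Sender:
    def __init__(self, model):
        self.model = model
        self.transports = []              # one transport per registered worker

    def register_worker(self, transport):
        self.transports.append(transport)

    def push(self):
        for t in self.transports:         # broadcast to every worker
            t.send(dict(self.model))


class Receiver:
    def __init__(self, model, transport):
        self.model, self.transport = model, transport

    def poll(self):
        weights = self.transport.receive()
        if weights is not None:
            self.model.update(weights)


# "Scheme" wiring: main-process sender, two worker-side receivers.
source = {"w": 1.0}
sender = Sender(source)
receivers = []
for _ in range(2):
    main_end, worker_end = mp.Pipe()
    sender.register_worker(PipeTransport(main_end))
    receivers.append(Receiver({"w": 0.0}, PipeTransport(worker_end)))

source["w"] = 3.0
sender.push()
for r in receivers:
    r.poll()
```

After ``push`` and ``poll``, both receiver-side models hold ``{"w": 3.0}``.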
284+
285+ Synchronization Phase
286+ ^^^^^^^^^^^^^^^^^^^^^
287+
288+ When you call ``collector.update_policy_weights_()``, the weight synchronization proceeds as follows:
289+
290+ .. aafig::
291+ :aspect: 60
292+ :scale: 130
293+ :proportional:
294+
295+ SYNCHRONIZATION PHASE
296+ =====================
297+
298+ Main Process Worker Process
299+
300+ +-------------------+ +-------------------+
301+ | WeightSender | | WeightReceiver |
302+ | | | |
303+ | 1. Extract | | 4. Poll transport |
304+ | weights from | | for weights |
305+ | model using | | |
306+ | strategy | | |
307+ | | 2. Send via | |
308+ | +-------------+ | Transport | +--------------+ |
309+ | | Strategy | | +------------+ | | Strategy | |
310+ | | extract() | | | | | | apply() | |
311+ | +-------------+ +----+ Transport +-------->+ +--------------+ |
312+ | | | | | | | |
313+ | v | +------------+ | v |
314+ | +-------------+ | | +--------------+ |
315+ | | Model | | | | Model | |
316+ | | (source) | | 3. Ack (optional) | | (dest) | |
317+ | +-------------+ | <-----------------------+ | +--------------+ |
318+ | | | |
319+ +-------------------+ | 5. Apply weights |
320+ | to model using |
321+ | strategy |
322+ +-------------------+
323+
324+ 1. **Extract**: Sender extracts weights from the source model (state_dict or TensorDict)
325+ 2. **Send**: Sender broadcasts weights through all registered transports
326+ 3. **Acknowledge** (optional): Some transports send an acknowledgment back
327+ 4. **Poll**: Receiver checks its transport for new weights
328+ 5. **Apply**: Receiver applies weights to the destination model
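
The steps above can be traced with a toy queue-based transport. This is illustrative only; real
transports are pipes, shared-memory buffers, or collective calls, and the names here are made up
for the sketch.

```python
import queue


class AckTransport:
    """Toy bidirectional transport: one queue per direction."""
    def __init__(self):
        self.down = queue.Queue()   # sender -> receiver (weights)
        self.up = queue.Queue()     # receiver -> sender (acknowledgments)


source = {"layer.weight": 0.5}      # trainer-side weights
dest = {"layer.weight": 0.0}        # worker-side weights
t = AckTransport()

weights = dict(source)              # 1. Extract: copy weights from the source model
t.down.put(weights)                 # 2. Send: push them through the transport
received = t.down.get(timeout=1.0)  # 4. Poll: worker side fetches new weights
dest.update(received)               # 5. Apply: write into the destination model
t.up.put("ack")                     # 3. Acknowledge (optional): confirm receipt
print(dest == source)               # True
```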
329+
330+ Multi-Model Synchronization
331+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
332+
333+ One of the key features is support for synchronizing multiple models independently:
334+
335+ .. aafig::
336+ :aspect: 60
337+ :scale: 130
338+ :proportional:
339+
340+ Main Process Worker Process 1 Worker Process 2
341+
342+ +-----------------+ +---------------+ +---------------+
343+ | Collector | | Collector | | Collector |
344+ | | | | | |
345+ | Models: | | Models: | | Models: |
346+ | +----------+ | | +--------+ | | +--------+ |
347+ | | Policy A | | | |Policy A| | | |Policy A| |
348+ | +----------+ | | +--------+ | | +--------+ |
349+ | +----------+ | | +--------+ | | +--------+ |
350+ | | Model B | | | |Model B| | | |Model B| |
351+ | +----------+ | | +--------+ | | +--------+ |
352+ | | | | | |
353+ | Weight Senders: | | Weight | | Weight |
354+ | +----------+ | | Receivers: | | Receivers: |
355+ | | Sender A +---+------------+->Receiver A | | Receiver A |
356+ | +----------+ | | | | |
357+ | +----------+ | | +--------+ | | +--------+ |
358+ | | Sender B +---+------------+->Receiver B | | Receiver B |
359+ | +----------+ | Pipes | | Pipes | |
360+ +-----------------+ +-------+-------+ +-------+-------+
361+
362+ Each model gets its own sender/receiver pair, allowing independent synchronization frequencies,
363+ different transport mechanisms per model, and model-specific strategies.
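
A minimal sketch of this per-model independence, using plain dicts as stand-ins for networks
(the ``Pair`` class and model names are hypothetical, not TorchRL API):

```python
class Pair:
    """Toy sender/receiver pair for one model ID: sync copies src into dst."""
    def __init__(self, src, dst):
        self.src, self.dst = src, dst

    def sync(self):
        self.dst.update(self.src)


# Main-process and worker-side copies of two independent models.
models_main = {"policy": {"w": 1.0}, "model_b": {"w": 10.0}}
models_worker = {"policy": {"w": 0.0}, "model_b": {"w": 0.0}}

# One pair per model ID, mirroring the one-sender/one-receiver-per-model design.
pairs = {name: Pair(models_main[name], models_worker[name])
         for name in models_main}

# Each model syncs on its own schedule: update only the policy here.
models_main["policy"]["w"] = 2.0
pairs["policy"].sync()
# model_b stays stale on the worker until its own sync is triggered.
```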
152364
153365 Usage Examples
154366 ~~~~~~~~~~~~~~
@@ -301,32 +513,55 @@ across multiple inference workers:
301513 dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured
302514 models and supports advanced features like lazy initialization.
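
The difference between the two strategies can be illustrated on plain dicts standing in for
tensors: a state_dict is a flat mapping with dotted keys, while a TensorDict-style structure is
nested. The ``nest`` helper below is a hypothetical illustration, not part of TorchRL.

```python
# state_dict-like payload: flat mapping with dotted keys.
flat = {"encoder.weight": 1.0, "head.weight": 2.0}


def nest(flat_sd):
    """Turn dotted keys into a nested dict, the shape a TensorDict would hold."""
    out = {}
    for key, value in flat_sd.items():
        node = out
        *path, leaf = key.split(".")
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out


nested = nest(flat)
print(nested)   # {'encoder': {'weight': 1.0}, 'head': {'weight': 2.0}}
```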
303515
304- Weight Senders
305- ~~~~~~~~~~~~~~
516+ API Reference
517+ ~~~~~~~~~~~~~
518+
519+ The weight synchronization system provides both user-facing configuration classes and internal
520+ implementation classes that are automatically managed by the collectors.
521+
522+ Schemes (User-Facing Configuration)
523+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
524+
525+ These are the main classes you'll use to configure weight synchronization. Pass them in the
526+ ``weight_sync_schemes`` dictionary when creating collectors.
306527
307528 .. currentmodule:: torchrl.weight_update
308529
309530 .. autosummary::
310531 :toctree: generated/
311532 :template: rl_template.rst
312533
313- WeightSender
314- RayModuleTransformSender
534+ WeightSyncScheme
535+ MultiProcessWeightSyncScheme
536+ SharedMemWeightSyncScheme
537+ NoWeightSyncScheme
538+ RayWeightSyncScheme
539+ RayModuleTransformScheme
540+ RPCWeightSyncScheme
541+ DistributedWeightSyncScheme
542+
543+ Senders and Receivers (Internal)
544+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
315545
316- Weight Receivers
317- ~~~~~~~~~~~~~~~~
546+ These classes are automatically created and managed by the schemes. You typically don't need
547+ to interact with them directly.
318548
319549 .. currentmodule:: torchrl.weight_update
320550
321551 .. autosummary::
322552 :toctree: generated/
323553 :template: rl_template.rst
324554
555+ WeightSender
325556 WeightReceiver
557+ RayModuleTransformSender
326558 RayModuleTransformReceiver
327559
328- Transports
329- ~~~~~~~~~~
560+ Transport Backends (Internal)
561+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
562+
563+ Transport classes handle the actual communication between processes. They are automatically
564+ selected and configured by the schemes.
330565
331566 .. currentmodule:: torchrl.weight_update
332567
@@ -342,24 +577,6 @@ Transports
342577 RPCTransport
343578 DistributedTransport
344579
345- Schemes
346- ~~~~~~~
347-
348- .. currentmodule :: torchrl.weight_update
349-
350- .. autosummary ::
351- :toctree: generated/
352- :template: rl_template.rst
353-
354- WeightSyncScheme
355- MultiProcessWeightSyncScheme
356- SharedMemWeightSyncScheme
357- NoWeightSyncScheme
358- RayWeightSyncScheme
359- RayModuleTransformScheme
360- RPCWeightSyncScheme
361- DistributedWeightSyncScheme
362-
363580 Legacy: Weight Synchronization in Distributed Environments
364581----------------------------------------------------------
365582