1919
2020from rapidsmpf .buffer .buffer import MemoryType
2121from rapidsmpf .buffer .resource import BufferResource , LimitAvailableMemory
22+ from rapidsmpf .buffer .rmm_fallback_resource import RmmFallbackResource
2223from rapidsmpf .buffer .spill_collection import SpillCollection
2324from rapidsmpf .communicator .ucxx import barrier , get_root_ucxx_address , new_communicator
2425from rapidsmpf .integrations .dask import _compat
@@ -187,6 +188,7 @@ def rmpf_worker_setup(
187188 * ,
188189 spill_device : float ,
189190 periodic_spill_check : float ,
191+ oom_protection : bool ,
190192 enable_statistics : bool ,
191193) -> None :
192194 """
@@ -204,6 +206,9 @@ def rmpf_worker_setup(
204206 by the buffer resource. The value of ``periodic_spill_check`` is used as
205207 the pause between checks (in seconds). If None, no periodic spill check
206208 is performed.
209+ oom_protection
210+ Enable out-of-memory protection by falling back to managed memory when
211+ the current device memory resource raises OOM errors.
207212 enable_statistics
208213 Whether to track shuffler statistics.
209214
@@ -236,9 +241,13 @@ def rmpf_worker_setup(
236241 assert ctx .comm is not None
237242 ctx .progress_thread = ProgressThread (ctx .comm , ctx .statistics )
238243
244+ mr = rmm .mr .get_current_device_resource ()
245+ if oom_protection :
246+ mr = RmmFallbackResource (mr , rmm .mr .ManagedMemoryResource ())
247+
239248 # Setup a buffer_resource.
240249 # Wrap the current RMM resource in statistics adaptor.
241- mr = rmm .mr .StatisticsResourceAdaptor (rmm . mr . get_current_device_resource () )
250+ mr = rmm .mr .StatisticsResourceAdaptor (mr )
242251 rmm .mr .set_current_device_resource (mr )
243252 total_memory = rmm .mr .available_device_memory ()[1 ]
244253 memory_available = {
@@ -307,6 +316,7 @@ def bootstrap_dask_cluster(
307316 * ,
308317 spill_device : float = 0.50 ,
309318 periodic_spill_check : float | None = 1e-3 ,
319+ oom_protection : bool = True ,
310320 enable_statistics : bool = True ,
311321) -> None :
312322 """
@@ -324,6 +334,9 @@ def bootstrap_dask_cluster(
324334 by the buffer resource. The value of ``periodic_spill_check`` is used as
325335 the pause between checks (in seconds). If None, no periodic spill
326336 check is performed.
337+ oom_protection
338+ Enable out-of-memory protection by falling back to managed memory when
339+ the current device memory resource raises OOM errors.
327340 enable_statistics
328341 Whether to track shuffler statistics.
329342
@@ -383,6 +396,7 @@ def bootstrap_dask_cluster(
383396 rmpf_worker_setup ,
384397 spill_device = spill_device ,
385398 periodic_spill_check = periodic_spill_check ,
399+ oom_protection = oom_protection ,
386400 enable_statistics = enable_statistics ,
387401 )
388402
0 commit comments