Skip to content

Commit 2a9a7c6

Browse files
committed
bootstrap_dask_cluster: oom_protection
1 parent c33c48c commit 2a9a7c6

File tree

1 file changed

+15
-1
lines changed
  • python/rapidsmpf/rapidsmpf/integrations/dask

1 file changed

+15
-1
lines changed

python/rapidsmpf/rapidsmpf/integrations/dask/core.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from rapidsmpf.buffer.buffer import MemoryType
2121
from rapidsmpf.buffer.resource import BufferResource, LimitAvailableMemory
22+
from rapidsmpf.buffer.rmm_fallback_resource import RmmFallbackResource
2223
from rapidsmpf.buffer.spill_collection import SpillCollection
2324
from rapidsmpf.communicator.ucxx import barrier, get_root_ucxx_address, new_communicator
2425
from rapidsmpf.integrations.dask import _compat
@@ -187,6 +188,7 @@ def rmpf_worker_setup(
187188
*,
188189
spill_device: float,
189190
periodic_spill_check: float,
191+
oom_protection: bool,
190192
enable_statistics: bool,
191193
) -> None:
192194
"""
@@ -204,6 +206,9 @@ def rmpf_worker_setup(
204206
by the buffer resource. The value of ``periodic_spill_check`` is used as
205207
the pause between checks (in seconds). If None, no periodic spill check
206208
is performed.
209+
oom_protection
210+
Enable out-of-memory protection by using managed memory when the device
211+
memory pool raises OOM errors.
207212
enable_statistics
208213
Whether to track shuffler statistics.
209214
@@ -236,9 +241,13 @@ def rmpf_worker_setup(
236241
assert ctx.comm is not None
237242
ctx.progress_thread = ProgressThread(ctx.comm, ctx.statistics)
238243

244+
mr = rmm.mr.get_current_device_resource()
245+
if oom_protection:
246+
mr = RmmFallbackResource(mr, rmm.mr.ManagedMemoryResource())
247+
239248
# Setup a buffer_resource.
240249
# Wrap the current RMM resource in statistics adaptor.
241-
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
250+
mr = rmm.mr.StatisticsResourceAdaptor(mr)
242251
rmm.mr.set_current_device_resource(mr)
243252
total_memory = rmm.mr.available_device_memory()[1]
244253
memory_available = {
@@ -307,6 +316,7 @@ def bootstrap_dask_cluster(
307316
*,
308317
spill_device: float = 0.50,
309318
periodic_spill_check: float | None = 1e-3,
319+
oom_protection: bool = True,
310320
enable_statistics: bool = True,
311321
) -> None:
312322
"""
@@ -324,6 +334,9 @@ def bootstrap_dask_cluster(
324334
by the buffer resource. The value of ``periodic_spill_check`` is used as
325335
the pause between checks (in seconds). If None, no periodic spill
326336
check is performed.
337+
oom_protection
338+
Enable out-of-memory protection by using managed memory when the device
339+
memory pool raises OOM errors.
327340
enable_statistics
328341
Whether to track shuffler statistics.
329342
@@ -383,6 +396,7 @@ def bootstrap_dask_cluster(
383396
rmpf_worker_setup,
384397
spill_device=spill_device,
385398
periodic_spill_check=periodic_spill_check,
399+
oom_protection=oom_protection,
386400
enable_statistics=enable_statistics,
387401
)
388402

0 commit comments

Comments
 (0)