Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions cpp/include/rapidsmpf/buffer/rmm_fallback_resource.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cstddef>
#include <mutex>
#include <unordered_set>

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

namespace rapidsmpf {


/**
* @brief A device memory resource that uses an alternate upstream resource when the
* primary upstream resource throws a specified exception type.
*
* An instance of this resource must be constructed with two upstream resources to satisfy
* allocation requests.
*
*/
class RmmFallbackResource final : public rmm::mr::device_memory_resource {
public:
using exception_type =
rmm::out_of_memory; ///< The type of exception this object catches/throws

/**
* @brief Construct a new `RmmFallbackResource` that uses `primary_upstream`
* to satisfy allocation requests and if that fails with `ExceptionType`, uses
* `alternate_upstream`.
*
* @param primary_upstream The primary resource used for allocating/deallocating
* device memory
* @param alternate_upstream The alternate resource used for allocating/deallocating
* device memory memory
*/
RmmFallbackResource(
rmm::device_async_resource_ref primary_upstream,
rmm::device_async_resource_ref alternate_upstream
)
: primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream} {}

RmmFallbackResource() = delete;
~RmmFallbackResource() override = default;

/**
* @brief Move constructor for RmmFallbackResource.
*
* @param other The RmmFallbackResource instance to move from.
* @return Reference to the moved instance.
*/
RmmFallbackResource& operator=(RmmFallbackResource&& other) noexcept = default;

/**
* @brief Get a reference to the primary upstream resource.
*
* @return Reference to the RMM memory resource.
*/
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept {
return primary_upstream_;
}

RmmFallbackResource(RmmFallbackResource const&) = delete;
RmmFallbackResource& operator=(RmmFallbackResource const&) = delete;

/**
* @brief Get a reference to the alternative upstream resource.
*
* This resource is used when primary upstream resource throws `exception_type`.
*
* @return Reference to the RMM memory resource.
*/
[[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource(
) const noexcept {
return alternate_upstream_;
}

private:
/**
* @brief Allocates memory of size at least `bytes` using the upstream
* resource.
*
* @throws any exceptions thrown from the upstream resources, only `exception_type`
* thrown by the primary upstream is caught.
*
* @param bytes The size, in bytes, of the allocation
* @param stream Stream on which to perform the allocation
* @return void* Pointer to the newly allocated memory
*/
void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
void* ret{};
try {
ret = primary_upstream_.allocate_async(bytes, stream);
} catch (exception_type const& e) {
ret = alternate_upstream_.allocate_async(bytes, stream);
std::lock_guard<std::mutex> lock(mtx_);
alternate_allocations_.insert(ret);
}
return ret;
}

/**
* @brief Free allocation of size `bytes` pointed to by `ptr`
*
* @param ptr Pointer to be deallocated
* @param bytes Size of the allocation
* @param stream Stream on which to perform the deallocation
*/
void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream)
override {
std::size_t count{0};
{
std::lock_guard<std::mutex> lock(mtx_);
count = alternate_allocations_.erase(ptr);
}
if (count > 0) {
alternate_upstream_.deallocate_async(ptr, bytes, stream);
} else {
primary_upstream_.deallocate_async(ptr, bytes, stream);
}
}

/**
* @brief Compare the resource to another.
*
* @param other The other resource to compare to
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
[[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
) const noexcept override {
if (this == &other) {
return true;
}
auto cast = dynamic_cast<RmmFallbackResource const*>(&other);
if (cast == nullptr) {
return false;
}
return get_upstream_resource() == cast->get_upstream_resource()
&& get_alternate_upstream_resource()
== cast->get_alternate_upstream_resource();
}

rmm::device_async_resource_ref primary_upstream_;
rmm::device_async_resource_ref alternate_upstream_;
std::unordered_set<void*> alternate_allocations_;
mutable std::mutex mtx_;
};


} // namespace rapidsmpf
87 changes: 87 additions & 0 deletions cpp/tests/test_rmm_fallback_resource.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/


#include <cstddef>
#include <stdexcept>
#include <unordered_set>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>
#include <rmm/device_buffer.hpp>

#include <rapidsmpf/buffer/rmm_fallback_resource.hpp>


using namespace rapidsmpf;

template <typename ExceptionType>
struct throw_at_limit_resource final : public rmm::mr::device_memory_resource {
throw_at_limit_resource(std::size_t limit) : limit{limit} {}

void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
if (bytes > limit) {
throw ExceptionType{"foo"};
}
void* ptr{nullptr};
RMM_CUDA_TRY_ALLOC(cudaMallocAsync(&ptr, bytes, stream));
allocs.insert(ptr);
return ptr;
}

void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override {
RMM_ASSERT_CUDA_SUCCESS(cudaFree(ptr));
allocs.erase(ptr);
}

[[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
) const noexcept override {
return this == &other;
}

const std::size_t limit;
std::unordered_set<void*> allocs{};
};

TEST(FailureAlternateTest, TrackBothUpstreams) {
throw_at_limit_resource<rmm::out_of_memory> primary_mr{100};
throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000};
RmmFallbackResource mr{primary_mr, alternate_mr};

// Check that a small allocation goes to the primary resource
{
void* a1 = mr.allocate(10);
EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{a1});
EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{});
mr.deallocate(a1, 10);
EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{});
EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{});
}

// Check that a large allocation goes to the alternate resource
{
void* a1 = mr.allocate(200);
EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{});
EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{a1});
mr.deallocate(a1, 200);
EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{});
EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{});
}

// Check that the exceptions raised by the alternate isn't caught
EXPECT_THROW(mr.allocate(2000), rmm::out_of_memory);
}

TEST(FailureAlternateTest, DifferentExceptionTypes) {
throw_at_limit_resource<std::invalid_argument> primary_mr{100};
throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000};
RmmFallbackResource mr{primary_mr, alternate_mr};

// Check that only `rmm::out_of_memory` exceptions are caught
EXPECT_THROW(mr.allocate(200), std::invalid_argument);
}
4 changes: 3 additions & 1 deletion python/rapidsmpf/rapidsmpf/buffer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
# =================================================================================

set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx)
set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx
rmm_fallback_resource.pyx
)

rapids_cython_create_modules(
CXX
Expand Down
15 changes: 15 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from rmm.pylibrmm.memory_resource import DeviceMemoryResource

class RmmFallbackResource:
def __init__(
self,
upstream_mr: DeviceMemoryResource,
alternate_upstream_mr: DeviceMemoryResource,
): ...
@property
def get_upstream(self) -> DeviceMemoryResource: ...
@property
def get_alternate_upstream(self) -> DeviceMemoryResource: ...
58 changes: 58 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from rmm.pylibrmm.memory_resource cimport (DeviceMemoryResource,
UpstreamResourceAdaptor,
device_memory_resource)


cdef extern from "<rapidsmpf/buffer/rmm_fallback_resource.hpp>" nogil:
cdef cppclass cpp_RmmFallbackResource"rapidsmpf::RmmFallbackResource"(
device_memory_resource
):
# Notice, `RmmFallbackResource` takes `device_async_resource_ref` as
# upstream arguments but we define them here as `device_memory_resource*`
# and rely on implicit type conversion.
cpp_RmmFallbackResource(
device_memory_resource* upstream_mr,
device_memory_resource* alternate_upstream_mr,
) except +


cdef class RmmFallbackResource(UpstreamResourceAdaptor):
cdef readonly DeviceMemoryResource alternate_upstream_mr

def __cinit__(
self,
DeviceMemoryResource upstream_mr,
DeviceMemoryResource alternate_upstream_mr,
):
if (alternate_upstream_mr is None):
raise Exception("Argument `alternate_upstream_mr` must not be None")
self.alternate_upstream_mr = alternate_upstream_mr

self.c_obj.reset(
new cpp_RmmFallbackResource(
upstream_mr.get_mr(),
alternate_upstream_mr.get_mr(),
)
)

def __init__(
self,
DeviceMemoryResource upstream_mr,
DeviceMemoryResource alternate_upstream_mr,
):
"""
A memory resource that uses an alternate resource when memory allocation fails.
Parameters
----------
upstream : DeviceMemoryResource
The primary resource used for allocating/deallocating device memory
alternate_upstream : DeviceMemoryResource
The alternate resource used when the primary fails to allocate
"""
pass

cpdef DeviceMemoryResource get_alternate_upstream(self):
return self.alternate_upstream_mr
16 changes: 15 additions & 1 deletion python/rapidsmpf/rapidsmpf/integrations/dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from rapidsmpf.buffer.buffer import MemoryType
from rapidsmpf.buffer.resource import BufferResource, LimitAvailableMemory
from rapidsmpf.buffer.rmm_fallback_resource import RmmFallbackResource
from rapidsmpf.buffer.spill_collection import SpillCollection
from rapidsmpf.communicator.ucxx import barrier, get_root_ucxx_address, new_communicator
from rapidsmpf.integrations.dask import _compat
Expand Down Expand Up @@ -187,6 +188,7 @@ def rmpf_worker_setup(
*,
spill_device: float,
periodic_spill_check: float,
oom_protection: bool,
enable_statistics: bool,
) -> None:
"""
Expand All @@ -204,6 +206,9 @@ def rmpf_worker_setup(
by the buffer resource. The value of ``periodic_spill_check`` is used as
the pause between checks (in seconds). If None, no periodic spill check
is performed.
oom_protection
Enable out-of-memory protection by using managed memory when the device
memory pool raises OOM errors.
enable_statistics
Whether to track shuffler statistics.

Expand Down Expand Up @@ -236,9 +241,13 @@ def rmpf_worker_setup(
assert ctx.comm is not None
ctx.progress_thread = ProgressThread(ctx.comm, ctx.statistics)

mr = rmm.mr.get_current_device_resource()
if oom_protection:
mr = RmmFallbackResource(mr, rmm.mr.ManagedMemoryResource())

# Setup a buffer_resource.
# Wrap the current RMM resource in statistics adaptor.
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
mr = rmm.mr.StatisticsResourceAdaptor(mr)
rmm.mr.set_current_device_resource(mr)
total_memory = rmm.mr.available_device_memory()[1]
memory_available = {
Expand Down Expand Up @@ -307,6 +316,7 @@ def bootstrap_dask_cluster(
*,
spill_device: float = 0.50,
periodic_spill_check: float | None = 1e-3,
oom_protection: bool = True,
enable_statistics: bool = True,
) -> None:
"""
Expand All @@ -324,6 +334,9 @@ def bootstrap_dask_cluster(
by the buffer resource. The value of ``periodic_spill_check`` is used as
the pause between checks (in seconds). If None, no periodic spill
check is performed.
oom_protection
Enable out-of-memory protection by using managed memory when the device
memory pool raises OOM errors.
enable_statistics
Whether to track shuffler statistics.

Expand Down Expand Up @@ -383,6 +396,7 @@ def bootstrap_dask_cluster(
rmpf_worker_setup,
spill_device=spill_device,
periodic_spill_check=periodic_spill_check,
oom_protection=oom_protection,
enable_statistics=enable_statistics,
)

Expand Down
Loading