Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions cpp/include/rapidsmpf/buffer/rmm_fallback_resource.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cstddef>
#include <mutex>
#include <unordered_set>

#include <rmm/error.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

namespace rapidsmpf {


/**
 * @brief A device memory resource that uses an alternate upstream resource when the
 * primary upstream resource throws `rmm::out_of_memory`.
 *
 * An instance of this resource must be constructed with two upstream resources to
 * satisfy allocation requests.
 *
 * The upstreams are stored as `rmm::device_async_resource_ref` members rather than as
 * copies of the resources, so the caller is expected to keep both upstream resources
 * alive for as long as this object is in use.
 *
 * Pointers handed out by the alternate upstream are recorded in a set guarded by a
 * mutex, which is how `do_deallocate` decides which upstream a pointer belongs to.
 */
class RmmFallbackResource final : public rmm::mr::device_memory_resource {
  public:
    /**
     * @brief Construct a new `RmmFallbackResource` that uses `primary_upstream`
     * to satisfy allocation requests and if that fails with `rmm::out_of_memory`,
     * uses `alternate_upstream`.
     *
     * @param primary_upstream The primary resource used for allocating/deallocating
     * device memory
     * @param alternate_upstream The alternate resource used for allocating/deallocating
     * device memory
     */
    RmmFallbackResource(
        rmm::device_async_resource_ref primary_upstream,
        rmm::device_async_resource_ref alternate_upstream
    )
        : primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream} {}

    RmmFallbackResource() = delete;
    ~RmmFallbackResource() override = default;

    /**
     * @brief Get a reference to the primary upstream resource.
     *
     * @return Reference to the RMM memory resource.
     */
    [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept {
        return primary_upstream_;
    }

    /**
     * @brief Get a reference to the alternative upstream resource.
     *
     * This resource is used when primary upstream resource throws `rmm::out_of_memory`.
     *
     * @return Reference to the RMM memory resource.
     */
    [[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource(
    ) const noexcept {
        return alternate_upstream_;
    }

  private:
    /**
     * @brief Allocates memory of size at least `bytes` using the upstream
     * resource.
     *
     * The primary upstream is tried first. Only if it throws `rmm::out_of_memory`
     * is the request forwarded to the alternate upstream, and the resulting pointer
     * is recorded so that it is later returned to the alternate upstream.
     *
     * @throws any exceptions thrown from the upstream resources, only
     * `rmm::out_of_memory` thrown by the primary upstream is caught.
     *
     * @param bytes The size, in bytes, of the allocation
     * @param stream Stream on which to perform the allocation
     * @return void* Pointer to the newly allocated memory
     */
    void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
        void* ret{};
        try {
            ret = primary_upstream_.allocate_async(bytes, stream);
        } catch (rmm::out_of_memory const&) {
            ret = alternate_upstream_.allocate_async(bytes, stream);
            try {
                std::lock_guard<std::mutex> lock(mutex_);
                alternate_allocations_.insert(ret);
            } catch (...) {
                // Recording the pointer failed (e.g. the set could not grow). The
                // caller will never see `ret`, so hand it back to the alternate
                // upstream before propagating the error instead of leaking it.
                alternate_upstream_.deallocate_async(ret, bytes, stream);
                throw;
            }
        }
        return ret;
    }

    /**
     * @brief Free allocation of size `bytes` pointed to by `ptr`
     *
     * The pointer is returned to the alternate upstream if it was recorded as an
     * alternate allocation, and to the primary upstream otherwise.
     *
     * @param ptr Pointer to be deallocated
     * @param bytes Size of the allocation
     * @param stream Stream on which to perform the deallocation
     */
    void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream)
        override {
        std::size_t count{0};
        {
            // Only the bookkeeping is done under the lock; the upstream call below
            // happens after the lock has been released.
            std::lock_guard<std::mutex> lock(mutex_);
            count = alternate_allocations_.erase(ptr);
        }
        if (count > 0) {
            alternate_upstream_.deallocate_async(ptr, bytes, stream);
        } else {
            primary_upstream_.deallocate_async(ptr, bytes, stream);
        }
    }

    /**
     * @brief Compare the resource to another.
     *
     * Two `RmmFallbackResource` instances are equal when they are the same object or
     * when both their primary and their alternate upstreams compare equal.
     *
     * @param other The other resource to compare to
     * @return true If the two resources are equivalent
     * @return false If the two resources are not equal
     */
    [[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
    ) const noexcept override {
        if (this == &other) {
            return true;
        }
        auto cast = dynamic_cast<RmmFallbackResource const*>(&other);
        if (cast == nullptr) {
            return false;
        }
        return get_upstream_resource() == cast->get_upstream_resource()
               && get_alternate_upstream_resource()
                      == cast->get_alternate_upstream_resource();
    }

    std::mutex mutex_;  ///< Guards `alternate_allocations_`.
    rmm::device_async_resource_ref primary_upstream_;
    rmm::device_async_resource_ref alternate_upstream_;
    /// Pointers currently allocated from `alternate_upstream_`.
    std::unordered_set<void*> alternate_allocations_;
};


} // namespace rapidsmpf
88 changes: 88 additions & 0 deletions cpp/tests/test_rmm_fallback_resource.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: Apache-2.0
*/


#include <cstddef>
#include <stdexcept>
#include <unordered_set>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/detail/error.hpp>
#include <rmm/device_buffer.hpp>

#include <rapidsmpf/buffer/rmm_fallback_resource.hpp>


using namespace rapidsmpf;

/**
 * @brief Test-only memory resource that refuses any request larger than `limit`.
 *
 * Requests above the limit throw `ExceptionType`; all other requests are served
 * with `cudaMallocAsync`. Every live pointer is kept in `allocs` so tests can
 * inspect which resource served an allocation.
 *
 * @tparam ExceptionType Exception thrown when a request exceeds the limit.
 */
template <typename ExceptionType>
struct throw_at_limit_resource final : public rmm::mr::device_memory_resource {
    throw_at_limit_resource(std::size_t limit) : limit{limit} {}

    /// Allocate `bytes` on `stream`, or throw `ExceptionType` if above the limit.
    void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
        if (bytes <= limit) {
            void* device_ptr{nullptr};
            RMM_CUDA_TRY_ALLOC(cudaMallocAsync(&device_ptr, bytes, stream));
            allocs.insert(device_ptr);
            return device_ptr;
        }
        throw ExceptionType{"foo"};
    }

    /// Free `ptr` and drop it from the set of live allocations.
    void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override {
        RMM_ASSERT_CUDA_SUCCESS(cudaFree(ptr));
        allocs.erase(ptr);
    }

    /// Instances are only equal to themselves.
    [[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
    ) const noexcept override {
        return &other == this;
    }

    std::size_t const limit;  ///< Largest request size that is served.
    std::unordered_set<void*> allocs{};  ///< Pointers currently allocated.
};

// Verifies that allocations are routed to, and released from, the upstream that
// actually served them.
TEST(FailureAlternateTest, TrackBothUpstreams) {
    using PtrSet = std::unordered_set<void*>;

    throw_at_limit_resource<rmm::out_of_memory> primary_mr{100};
    throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000};
    RmmFallbackResource mr{primary_mr, alternate_mr};

    // A request below the primary limit is served by the primary resource.
    void* small_alloc = mr.allocate(10);
    EXPECT_EQ(primary_mr.allocs, PtrSet{small_alloc});
    EXPECT_EQ(alternate_mr.allocs, PtrSet{});
    mr.deallocate(small_alloc, 10);
    EXPECT_EQ(primary_mr.allocs, PtrSet{});
    EXPECT_EQ(alternate_mr.allocs, PtrSet{});

    // A request above the primary limit falls through to the alternate resource.
    void* large_alloc = mr.allocate(200);
    EXPECT_EQ(primary_mr.allocs, PtrSet{});
    EXPECT_EQ(alternate_mr.allocs, PtrSet{large_alloc});
    mr.deallocate(large_alloc, 200);
    EXPECT_EQ(primary_mr.allocs, PtrSet{});
    EXPECT_EQ(alternate_mr.allocs, PtrSet{});

    // A request above both limits surfaces the alternate resource's error.
    EXPECT_THROW(mr.allocate(2000), rmm::out_of_memory);
}

// Verifies that a non-OOM exception from the primary resource is not redirected
// to the alternate resource.
TEST(FailureAlternateTest, DifferentExceptionTypes) {
    throw_at_limit_resource<std::invalid_argument> throwing_primary{100};
    throw_at_limit_resource<rmm::out_of_memory> fallback{1000};
    RmmFallbackResource mr{throwing_primary, fallback};

    // `RmmFallbackResource` only catches `rmm::out_of_memory`, so the primary's
    // `std::invalid_argument` must propagate to the caller.
    EXPECT_THROW(mr.allocate(200), std::invalid_argument);
}
4 changes: 3 additions & 1 deletion python/rapidsmpf/rapidsmpf/buffer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
# =================================================================================

set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx)
set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx
rmm_fallback_resource.pyx
)

rapids_cython_create_modules(
CXX
Expand Down
15 changes: 15 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from rmm.pylibrmm.memory_resource import DeviceMemoryResource

class RmmFallbackResource:
    # Type stub for the Cython extension class of the same name.
    #
    # `upstream_mr` is the primary resource; `alternate_upstream_mr` is the
    # resource used when the primary fails to allocate.
    def __init__(
        self,
        upstream_mr: DeviceMemoryResource,
        alternate_upstream_mr: DeviceMemoryResource,
    ) -> None: ...
    # Read-only attribute exposed by the extension class.
    @property
    def alternate_upstream_mr(self) -> DeviceMemoryResource: ...
    # The getters are regular methods (they must be called), not properties.
    def get_upstream(self) -> DeviceMemoryResource: ...
    def get_alternate_upstream(self) -> DeviceMemoryResource: ...
59 changes: 59 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from rmm.pylibrmm.memory_resource cimport (DeviceMemoryResource,
UpstreamResourceAdaptor,
device_memory_resource)


# Expose the C++ class `rapidsmpf::RmmFallbackResource` to Cython under the
# name `cpp_RmmFallbackResource`, declared as a subclass of
# `device_memory_resource`.
cdef extern from "<rapidsmpf/buffer/rmm_fallback_resource.hpp>" nogil:
    cdef cppclass cpp_RmmFallbackResource"rapidsmpf::RmmFallbackResource"(
        device_memory_resource
    ):
        # Notice, `RmmFallbackResource` takes `device_async_resource_ref` as
        # upstream arguments but we define them here as `device_memory_resource*`
        # and rely on implicit type conversion.
        # `except +` lets C++ exceptions thrown by the constructor surface as
        # Python exceptions.
        cpp_RmmFallbackResource(
            device_memory_resource* upstream_mr,
            device_memory_resource* alternate_upstream_mr,
        ) except +


# Python wrapper around the C++ `RmmFallbackResource`: allocations go to
# `upstream_mr` and fall back to `alternate_upstream_mr`.
cdef class RmmFallbackResource(UpstreamResourceAdaptor):
    # Python-level reference to the alternate upstream, readable from Python.
    cdef readonly DeviceMemoryResource alternate_upstream_mr

    def __cinit__(
        self,
        DeviceMemoryResource upstream_mr,
        DeviceMemoryResource alternate_upstream_mr,
    ):
        # Note, `upstream_mr is None` is checked by `UpstreamResourceAdaptor`.
        if alternate_upstream_mr is None:
            raise Exception("Argument `alternate_upstream_mr` must not be None")
        # Store the alternate upstream on the instance; presumably this also keeps
        # it alive while the C++ object refers to it -- TODO confirm.
        self.alternate_upstream_mr = alternate_upstream_mr

        # Create the underlying C++ resource from the raw pointers of both
        # upstreams and hand it to `c_obj`.
        self.c_obj.reset(
            new cpp_RmmFallbackResource(
                upstream_mr.get_mr(),
                alternate_upstream_mr.get_mr(),
            )
        )

    def __init__(
        self,
        DeviceMemoryResource upstream_mr,
        DeviceMemoryResource alternate_upstream_mr,
    ):
        """
        A memory resource that uses an alternate resource when memory allocation fails.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            The primary resource used for allocating/deallocating device memory
        alternate_upstream_mr : DeviceMemoryResource
            The alternate resource used when the primary fails to allocate
        """
        # All initialisation happens in `__cinit__`; this method only carries
        # the docstring.
        pass

    # Return the alternate upstream resource given at construction.
    cpdef DeviceMemoryResource get_alternate_upstream(self):
        return self.alternate_upstream_mr
16 changes: 15 additions & 1 deletion python/rapidsmpf/rapidsmpf/integrations/dask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from rapidsmpf.buffer.buffer import MemoryType
from rapidsmpf.buffer.resource import BufferResource, LimitAvailableMemory
from rapidsmpf.buffer.rmm_fallback_resource import RmmFallbackResource
from rapidsmpf.buffer.spill_collection import SpillCollection
from rapidsmpf.communicator.ucxx import barrier, get_root_ucxx_address, new_communicator
from rapidsmpf.integrations.dask import _compat
Expand Down Expand Up @@ -187,6 +188,7 @@ def rmpf_worker_setup(
*,
spill_device: float,
periodic_spill_check: float,
oom_protection: bool,
enable_statistics: bool,
) -> None:
"""
Expand All @@ -204,6 +206,9 @@ def rmpf_worker_setup(
by the buffer resource. The value of ``periodic_spill_check`` is used as
the pause between checks (in seconds). If None, no periodic spill check
is performed.
oom_protection
Enable out-of-memory protection by using managed memory when the device
memory pool raises OOM errors.
enable_statistics
Whether to track shuffler statistics.

Expand Down Expand Up @@ -236,9 +241,13 @@ def rmpf_worker_setup(
assert ctx.comm is not None
ctx.progress_thread = ProgressThread(ctx.comm, ctx.statistics)

mr = rmm.mr.get_current_device_resource()
if oom_protection:
mr = RmmFallbackResource(mr, rmm.mr.ManagedMemoryResource())

# Setup a buffer_resource.
# Wrap the current RMM resource in statistics adaptor.
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
mr = rmm.mr.StatisticsResourceAdaptor(mr)
rmm.mr.set_current_device_resource(mr)
total_memory = rmm.mr.available_device_memory()[1]
memory_available = {
Expand Down Expand Up @@ -307,6 +316,7 @@ def bootstrap_dask_cluster(
*,
spill_device: float = 0.50,
periodic_spill_check: float | None = 1e-3,
oom_protection: bool = False,
enable_statistics: bool = True,
) -> None:
"""
Expand All @@ -324,6 +334,9 @@ def bootstrap_dask_cluster(
by the buffer resource. The value of ``periodic_spill_check`` is used as
the pause between checks (in seconds). If None, no periodic spill
check is performed.
oom_protection
Enable out-of-memory protection by using managed memory when the device
memory pool raises OOM errors.
enable_statistics
Whether to track shuffler statistics.

Expand Down Expand Up @@ -383,6 +396,7 @@ def bootstrap_dask_cluster(
rmpf_worker_setup,
spill_device=spill_device,
periodic_spill_check=periodic_spill_check,
oom_protection=oom_protection,
enable_statistics=enable_statistics,
)

Expand Down
Loading