-
Notifications
You must be signed in to change notification settings - Fork 22
OOM protection by fallback to managed memory. #287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
rapids-bot
merged 10 commits into
rapidsai:branch-25.06
from
madsbk:rmm_fallback_resource
May 20, 2025
Merged
Changes from 4 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e228c61
RmmFallbackResource
madsbk 70bff9a
always OOM
madsbk c33c48c
python bindings
madsbk 2a9a7c6
bootstrap_dask_cluster: oom_protection
madsbk 04b607e
Merge branch 'branch-25.06' of github.com:rapidsai/rapidsmpf into rmm…
madsbk c4a2ffd
default False
madsbk e39e1ec
docs
madsbk 8d2887d
cleanup
madsbk 8afa638
doc
madsbk 84e2050
test_except_type
madsbk File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| /** | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstddef> | ||
| #include <mutex> | ||
| #include <unordered_set> | ||
|
|
||
| #include <rmm/mr/device/device_memory_resource.hpp> | ||
| #include <rmm/resource_ref.hpp> | ||
|
|
||
| namespace rapidsmpf { | ||
|
|
||
|
|
||
| /** | |
| * @brief A device memory resource that uses an alternate upstream resource when the | |
| * primary upstream resource throws `rmm::out_of_memory`. | |
| * | |
| * An instance of this resource must be constructed with two upstream resources to satisfy | |
| * allocation requests. Pointers handed out by the alternate upstream are recorded (under | |
| * a mutex) so that each deallocation is forwarded to the upstream that produced it. | |
| * | |
| * Both upstreams are stored as `rmm::device_async_resource_ref`; presumably the caller | |
| * must keep the referenced resources alive for the lifetime of this object. | |
| */ | |
| class RmmFallbackResource final : public rmm::mr::device_memory_resource { | |
| public: | |
| using exception_type = | |
| rmm::out_of_memory; ///< The type of exception this object catches/throws | |
|
|
||
| /** | |
| * @brief Construct a new `RmmFallbackResource` that uses `primary_upstream` | |
| * to satisfy allocation requests and if that fails with `exception_type`, uses | |
| * `alternate_upstream`. | |
| * | |
| * @param primary_upstream The primary resource used for allocating/deallocating | |
| * device memory | |
| * @param alternate_upstream The alternate resource used for allocating/deallocating | |
| * device memory when the primary throws `exception_type` | |
| */ | |
| RmmFallbackResource( | |
| rmm::device_async_resource_ref primary_upstream, | |
| rmm::device_async_resource_ref alternate_upstream | |
| ) | |
| : primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream} {} | |
|
|
||
| RmmFallbackResource() = delete; | |
| ~RmmFallbackResource() override = default; | |
|
|
||
| /** | |
| * @brief Move assignment operator for RmmFallbackResource. | |
| * | |
| * @note The class holds a `std::mutex`, which is not move-assignable, so this | |
| * defaulted operator is implicitly defined as deleted. | |
| * | |
| * @param other The RmmFallbackResource instance to move from. | |
| * @return Reference to the assigned-to instance. | |
| */ | |
| RmmFallbackResource& operator=(RmmFallbackResource&& other) noexcept = default; | |
|
|
||
| /** | |
| * @brief Get a reference to the primary upstream resource. | |
| * | |
| * @return Reference to the RMM memory resource. | |
| */ | |
| [[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept { | |
| return primary_upstream_; | |
| } | |
|
|
||
| RmmFallbackResource(RmmFallbackResource const&) = delete; | |
| RmmFallbackResource& operator=(RmmFallbackResource const&) = delete; | |
|
|
||
| /** | |
| * @brief Get a reference to the alternative upstream resource. | |
| * | |
| * This resource is used when primary upstream resource throws `exception_type`. | |
| * | |
| * @return Reference to the RMM memory resource. | |
| */ | |
| [[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource( | |
| ) const noexcept { | |
| return alternate_upstream_; | |
| } | |
|
|
||
| private: | |
| /** | |
| * @brief Allocates memory of size at least `bytes`, trying the primary upstream | |
| * first and the alternate upstream if the primary throws `exception_type`. | |
| * | |
| * Pointers obtained from the alternate upstream are recorded so that | |
| * `do_deallocate` can return them to the alternate upstream. | |
| * | |
| * @throws any exceptions thrown from the upstream resources, only `exception_type` | |
| * thrown by the primary upstream is caught. | |
| * | |
| * @param bytes The size, in bytes, of the allocation | |
| * @param stream Stream on which to perform the allocation | |
| * @return void* Pointer to the newly allocated memory | |
| */ | |
| void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { | |
| try { | |
| return primary_upstream_.allocate_async(bytes, stream); | |
| } catch (exception_type const&) { | |
| // Primary is out of memory: fall through to the alternate upstream below. | |
| } | |
| void* ret = alternate_upstream_.allocate_async(bytes, stream); | |
| try { | |
| std::lock_guard<std::mutex> lock(mtx_); | |
| alternate_allocations_.insert(ret); | |
| } catch (...) { | |
| // Recording the pointer failed; release it so it is not leaked, then rethrow. | |
| alternate_upstream_.deallocate_async(ret, bytes, stream); | |
| throw; | |
| } | |
| return ret; | |
| } | |
|
|
||
| /** | |
| * @brief Free allocation of size `bytes` pointed to by `ptr` | |
| * | |
| * The deallocation is forwarded to the alternate upstream if `ptr` was recorded | |
| * as an alternate allocation, otherwise to the primary upstream. | |
| * | |
| * @param ptr Pointer to be deallocated | |
| * @param bytes Size of the allocation | |
| * @param stream Stream on which to perform the deallocation | |
| */ | |
| void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream) | |
| override { | |
| std::size_t count{0}; | |
| { | |
| // Hold the lock only while updating the set, not across the upstream call. | |
| std::lock_guard<std::mutex> lock(mtx_); | |
| count = alternate_allocations_.erase(ptr); | |
| } | |
| if (count > 0) { | |
| alternate_upstream_.deallocate_async(ptr, bytes, stream); | |
| } else { | |
| primary_upstream_.deallocate_async(ptr, bytes, stream); | |
| } | |
| } | |
|
|
||
| /** | |
| * @brief Compare the resource to another. | |
| * | |
| * Two resources are equal when they are the same object, or when `other` is also | |
| * an `RmmFallbackResource` and both its primary and alternate upstreams compare | |
| * equal to ours. | |
| * | |
| * @param other The other resource to compare to | |
| * @return true If the two resources are equivalent | |
| * @return false If the two resources are not equal | |
| */ | |
| [[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other | |
| ) const noexcept override { | |
| if (this == &other) { | |
| return true; | |
| } | |
| auto const* cast = dynamic_cast<RmmFallbackResource const*>(&other); | |
| if (cast == nullptr) { | |
| return false; | |
| } | |
| return get_upstream_resource() == cast->get_upstream_resource() | |
| && get_alternate_upstream_resource() | |
| == cast->get_alternate_upstream_resource(); | |
| } | |
|
|
||
| rmm::device_async_resource_ref primary_upstream_; ///< Tried first for every allocation | |
| rmm::device_async_resource_ref alternate_upstream_; ///< Used when the primary throws `exception_type` | |
| std::unordered_set<void*> alternate_allocations_; ///< Live pointers obtained from the alternate upstream | |
| mutable std::mutex mtx_; ///< Guards `alternate_allocations_` | |
| }; | |
|
|
||
|
|
||
| } // namespace rapidsmpf | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| /** | ||
| * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
|
|
||
| #include <cstddef> | ||
| #include <stdexcept> | ||
| #include <unordered_set> | ||
|
|
||
| #include <gmock/gmock.h> | ||
| #include <gtest/gtest.h> | ||
|
|
||
| #include <rmm/cuda_stream_view.hpp> | ||
| #include <rmm/detail/error.hpp> | ||
| #include <rmm/device_buffer.hpp> | ||
|
|
||
| #include <rapidsmpf/buffer/rmm_fallback_resource.hpp> | ||
|
|
||
|
|
||
| using namespace rapidsmpf; | ||
|
|
||
| // Test-only memory resource: throws `ExceptionType` for any request larger than | |
| // `limit` bytes, otherwise allocates with `cudaMallocAsync` and records the | |
| // returned pointer in `allocs` until it is deallocated. | |
| template <typename ExceptionType> | |
| struct throw_at_limit_resource final : public rmm::mr::device_memory_resource { | |
| // `explicit` so a bare size is never implicitly converted into a resource. | |
| explicit throw_at_limit_resource(std::size_t limit) : limit{limit} {} | |
|
|
||
| // Throws `ExceptionType{"foo"}` when `bytes > limit`; otherwise allocates on | |
| // `stream` and tracks the pointer. | |
| void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { | |
| if (bytes > limit) { | |
| throw ExceptionType{"foo"}; | |
| } | |
| void* ptr{nullptr}; | |
| RMM_CUDA_TRY_ALLOC(cudaMallocAsync(&ptr, bytes, stream)); | |
| allocs.insert(ptr); | |
| return ptr; | |
| } | |
|
|
||
| // Frees `ptr` with `cudaFree` and stops tracking it; size and stream are ignored. | |
| void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override { | |
| RMM_ASSERT_CUDA_SUCCESS(cudaFree(ptr)); | |
| allocs.erase(ptr); | |
| } | |
|
|
||
| // Equal only to itself (identity comparison). | |
| [[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other | |
| ) const noexcept override { | |
| return this == &other; | |
| } | |
|
|
||
| const std::size_t limit; // Largest request size (in bytes) that is served | |
| std::unordered_set<void*> allocs{}; // Pointers currently allocated through this resource | |
| }; | |
|
|
||
| TEST(FailureAlternateTest, TrackBothUpstreams) { | |
| throw_at_limit_resource<rmm::out_of_memory> primary_mr{100}; | |
| throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000}; | |
| RmmFallbackResource mr{primary_mr, alternate_mr}; | |
|
|
||
| // Check that a small allocation (10 bytes, within the primary's limit of 100) goes to the primary resource and is returned to it on deallocation | |
| { | |
| void* a1 = mr.allocate(10); | |
| EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{a1}); | |
| EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{}); | |
| mr.deallocate(a1, 10); | |
| EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{}); | |
| EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{}); | |
| } | |
|
|
||
| // Check that a large allocation (200 bytes: over the primary's limit of 100, within the alternate's limit of 1000) goes to the alternate resource and is returned to it on deallocation | |
| { | |
| void* a1 = mr.allocate(200); | |
| EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{}); | |
| EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{a1}); | |
| mr.deallocate(a1, 200); | |
| EXPECT_EQ(primary_mr.allocs, std::unordered_set<void*>{}); | |
| EXPECT_EQ(alternate_mr.allocs, std::unordered_set<void*>{}); | |
| } | |
|
|
||
| // Check that an exception raised by the alternate isn't caught (2000 bytes exceeds both limits) | |
pentschev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| EXPECT_THROW(mr.allocate(2000), rmm::out_of_memory); | |
| } | |
|
|
||
| TEST(FailureAlternateTest, DifferentExceptionTypes) { | |
| throw_at_limit_resource<std::invalid_argument> primary_mr{100}; | |
| throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000}; | |
| RmmFallbackResource mr{primary_mr, alternate_mr}; | |
|
|
||
| // Check that only `rmm::out_of_memory` exceptions are caught: the primary throws `std::invalid_argument` here, which must propagate instead of triggering the fallback | |
pentschev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| EXPECT_THROW(mr.allocate(200), std::invalid_argument); | |
| } | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyi
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from rmm.pylibrmm.memory_resource import DeviceMemoryResource | ||
|
|
||
| class RmmFallbackResource: | ||
| def __init__( | ||
| self, | ||
| upstream_mr: DeviceMemoryResource, | ||
| alternate_upstream_mr: DeviceMemoryResource, | ||
| ): ... | ||
| @property | ||
| def get_upstream(self) -> DeviceMemoryResource: ... | ||
| @property | ||
| def get_alternate_upstream(self) -> DeviceMemoryResource: ... |
58 changes: 58 additions & 0 deletions
58
python/rapidsmpf/rapidsmpf/buffer/rmm_fallback_resource.pyx
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from rmm.pylibrmm.memory_resource cimport (DeviceMemoryResource, | ||
| UpstreamResourceAdaptor, | ||
| device_memory_resource) | ||
|
|
||
|
|
||
| cdef extern from "<rapidsmpf/buffer/rmm_fallback_resource.hpp>" nogil: | |
| cdef cppclass cpp_RmmFallbackResource"rapidsmpf::RmmFallbackResource"( | |
| device_memory_resource | |
| ): | |
| # NOTE: the C++ constructor takes `device_async_resource_ref` for both | |
| # upstream arguments; they are declared here as `device_memory_resource*` | |
| # and we rely on the implicit conversion from pointer to resource ref. | |
| cpp_RmmFallbackResource( | |
| device_memory_resource* upstream_mr, | |
| device_memory_resource* alternate_upstream_mr, | |
| ) except + | |
|
|
||
|
|
||
| cdef class RmmFallbackResource(UpstreamResourceAdaptor): | |
| cdef readonly DeviceMemoryResource alternate_upstream_mr | |
|
|
||
| def __cinit__( | |
| self, | |
| DeviceMemoryResource upstream_mr, | |
| DeviceMemoryResource alternate_upstream_mr, | |
| ): | |
| if (alternate_upstream_mr is None): | |
| raise Exception("Argument `alternate_upstream_mr` must not be None") | |
pentschev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self.alternate_upstream_mr = alternate_upstream_mr | |
|
|
||
| self.c_obj.reset( | |
| new cpp_RmmFallbackResource( | |
| upstream_mr.get_mr(), | |
| alternate_upstream_mr.get_mr(), | |
| ) | |
| ) | |
|
|
||
| def __init__( | |
| self, | |
| DeviceMemoryResource upstream_mr, | |
| DeviceMemoryResource alternate_upstream_mr, | |
| ): | |
| """ | |
| A memory resource that uses an alternate resource when memory allocation fails. | |
|
|
||
| Parameters | |
| ---------- | |
| upstream_mr : DeviceMemoryResource | |
| The primary resource used for allocating/deallocating device memory. | |
| alternate_upstream_mr : DeviceMemoryResource | |
| The alternate resource used when the primary fails to allocate. | |
| Must not be None. | |
| """ | |
| pass | |
|
|
||
| cpdef DeviceMemoryResource get_alternate_upstream(self): | |
| return self.alternate_upstream_mr | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.