Skip to content

Commit 5487cac

Browse files
griwesjrhemstad
andauthored
thrust/mr: fix the case of reuising a block for a smaller alloc. (#1232) (#1317)
* thrust/mr: fix the case of reuising a block for a smaller alloc. Previously, the pool happily returned a pointer to a larger oversized block than requested, without storing the information that the block is now smaller, which meant that on deallocation, it'd look for the descriptor of the block in the wrong place. This is now fixed by moving the descriptor to always be where deallocation can find it using the user-provided size, and by storing the original size to restore the descriptor to its rightful place when deallocating. Also a drive-by fix for a bug where in certain cases the reallocated cached oversized block wasn't removed from the cached list. Whoops. Kinda surprised this hasn't exploded before. * thrust/mr: add aliases to reused pointer traits in pool.h Co-authored-by: Jake Hemstad <[email protected]>
1 parent d4ca07a commit 5487cac

File tree

2 files changed

+107
-42
lines changed

2 files changed

+107
-42
lines changed

thrust/testing/mr_pool.cu

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,26 @@ public:
123123

124124
virtual tracked_pointer<void> do_allocate(std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
125125
{
126-
ASSERT_EQUAL(static_cast<bool>(id_to_allocate), true);
126+
ASSERT_EQUAL(id_to_allocate || id_to_allocate == -1u, true);
127127

128128
void * raw = upstream.do_allocate(n, alignment);
129129
tracked_pointer<void> ret(raw);
130130
ret.id = id_to_allocate;
131131
ret.size = n;
132132
ret.alignment = alignment;
133133

134-
id_to_allocate = 0;
134+
if (id_to_allocate != -1u)
135+
{
136+
id_to_allocate = 0;
137+
}
135138

136139
return ret;
137140
}
138141

139142
virtual void do_deallocate(tracked_pointer<void> p, std::size_t n, std::size_t alignment = THRUST_MR_DEFAULT_ALIGNMENT) override
140143
{
141-
ASSERT_EQUAL(p.size, n);
142-
ASSERT_EQUAL(p.alignment, alignment);
144+
ASSERT_GEQUAL(p.size, n);
145+
ASSERT_GEQUAL(p.alignment, alignment);
143146

144147
if (id_to_deallocate != 0)
145148
{
@@ -318,6 +321,36 @@ void TestPoolCachingOversized()
318321
upstream.id_to_allocate = 7;
319322
tracked_pointer<void> a9 = pool.do_allocate(2048, 32);
320323
ASSERT_EQUAL(a9.id, 7u);
324+
325+
// make sure that reusing a larger oversized block for a smaller allocation works
326+
// this is NVIDIA/cccl#585
327+
upstream.id_to_allocate = 8;
328+
tracked_pointer<void> a10 = pool.do_allocate(2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
329+
pool.do_deallocate(a10, 2048 + 16, THRUST_MR_DEFAULT_ALIGNMENT);
330+
tracked_pointer<void> a11 = pool.do_allocate(2048, THRUST_MR_DEFAULT_ALIGNMENT);
331+
ASSERT_EQUAL(a11.ptr, a10.ptr);
332+
pool.do_deallocate(a11, 2048, THRUST_MR_DEFAULT_ALIGNMENT);
333+
334+
// original minimized reproducer from NVIDIA/cccl#585:
335+
{
336+
upstream.id_to_allocate = -1u;
337+
338+
auto ptr1 = pool.allocate(43920240);
339+
auto ptr2 = pool.allocate(2465264);
340+
pool.deallocate(ptr1, 43920240);
341+
pool.deallocate(ptr2, 2465264);
342+
auto ptr3 = pool.allocate(4930528);
343+
pool.deallocate(ptr3, 4930528);
344+
auto ptr4 = pool.allocate(14640080);
345+
std::memset(thrust::raw_pointer_cast(ptr4), 0xff, 14640080);
346+
347+
auto crash = pool.allocate(4930528);
348+
349+
pool.deallocate(crash, 4930528);
350+
pool.deallocate(ptr4, 14640080);
351+
352+
upstream.id_to_allocate = 0;
353+
}
321354
}
322355

323356
void TestUnsynchronizedPoolCachingOversized()

thrust/thrust/mr/pool.h

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,17 @@ class unsynchronized_pool_resource final
154154

155155
private:
156156
typedef typename Upstream::pointer void_ptr;
157-
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<char>::other char_ptr;
157+
typedef thrust::detail::pointer_traits<void_ptr> void_ptr_traits;
158+
typedef typename void_ptr_traits::template rebind<char>::other char_ptr;
158159

159160
struct block_descriptor;
160161
struct chunk_descriptor;
161162
struct oversized_block_descriptor;
162163

163-
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<block_descriptor>::other block_descriptor_ptr;
164-
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
165-
typedef typename thrust::detail::pointer_traits<void_ptr>::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
164+
typedef typename void_ptr_traits::template rebind<block_descriptor>::other block_descriptor_ptr;
165+
typedef typename void_ptr_traits::template rebind<chunk_descriptor>::other chunk_descriptor_ptr;
166+
typedef typename void_ptr_traits::template rebind<oversized_block_descriptor>::other oversized_block_descriptor_ptr;
167+
typedef thrust::detail::pointer_traits<oversized_block_descriptor_ptr> oversized_block_ptr_traits;
166168

167169
struct block_descriptor
168170
{
@@ -194,6 +196,7 @@ class unsynchronized_pool_resource final
194196
oversized_block_descriptor_ptr prev;
195197
oversized_block_descriptor_ptr next;
196198
oversized_block_descriptor_ptr next_cached;
199+
std::size_t current_size;
197200
};
198201

199202
struct pool
@@ -244,17 +247,20 @@ class unsynchronized_pool_resource final
244247
}
245248

246249
// deallocate cached oversized/overaligned memory
247-
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(m_oversized))
250+
while (oversized_block_ptr_traits::get(m_oversized))
248251
{
249252
oversized_block_descriptor_ptr alloc = m_oversized;
250253
m_oversized = thrust::raw_reference_cast(*m_oversized).next;
251254

255+
oversized_block_descriptor desc =
256+
thrust::raw_reference_cast(*alloc);
257+
252258
void_ptr p = static_cast<void_ptr>(
253-
static_cast<char_ptr>(
254-
static_cast<void_ptr>(alloc)
255-
) - thrust::raw_reference_cast(*alloc).size
256-
);
257-
m_upstream->do_deallocate(p, thrust::raw_reference_cast(*alloc).size + sizeof(oversized_block_descriptor), thrust::raw_reference_cast(*alloc).alignment);
259+
static_cast<char_ptr>(static_cast<void_ptr>(alloc)) -
260+
desc.current_size);
261+
m_upstream->do_deallocate(
262+
p, desc.size + sizeof(oversized_block_descriptor),
263+
desc.alignment);
258264
}
259265

260266
m_cached_oversized = oversized_block_descriptor_ptr();
@@ -272,7 +278,7 @@ class unsynchronized_pool_resource final
272278
{
273279
oversized_block_descriptor_ptr ptr = m_cached_oversized;
274280
oversized_block_descriptor_ptr * previous = &m_cached_oversized;
275-
while (detail::pointer_traits<oversized_block_descriptor_ptr>::get(ptr))
281+
while (oversized_block_ptr_traits::get(ptr))
276282
{
277283
oversized_block_descriptor desc = *ptr;
278284
bool is_good = desc.size >= bytes && desc.alignment >= alignment;
@@ -305,23 +311,39 @@ class unsynchronized_pool_resource final
305311
{
306312
if (previous != &m_cached_oversized)
307313
{
308-
oversized_block_descriptor previous_desc = **previous;
309-
previous_desc.next_cached = desc.next_cached;
310-
**previous = previous_desc;
314+
*previous = desc.next_cached;
311315
}
312316
else
313317
{
314318
m_cached_oversized = desc.next_cached;
315319
}
316320

317321
desc.next_cached = oversized_block_descriptor_ptr();
322+
323+
auto ret =
324+
static_cast<char_ptr>(static_cast<void_ptr>(ptr)) -
325+
desc.size;
326+
327+
if (bytes != desc.size) {
328+
desc.current_size = bytes;
329+
330+
ptr = static_cast<oversized_block_descriptor_ptr>(
331+
static_cast<void_ptr>(ret + bytes));
332+
333+
if (oversized_block_ptr_traits::get(desc.prev)) {
334+
thrust::raw_reference_cast(*desc.prev).next = ptr;
335+
} else {
336+
m_oversized = ptr;
337+
}
338+
339+
if (oversized_block_ptr_traits::get(desc.next)) {
340+
thrust::raw_reference_cast(*desc.next).prev = ptr;
341+
}
342+
}
343+
318344
*ptr = desc;
319345

320-
return static_cast<void_ptr>(
321-
static_cast<char_ptr>(
322-
static_cast<void_ptr>(ptr)
323-
) - desc.size
324-
);
346+
return static_cast<void_ptr>(ret);
325347
}
326348

327349
previous = &thrust::raw_reference_cast(*ptr).next_cached;
@@ -343,10 +365,11 @@ class unsynchronized_pool_resource final
343365
desc.prev = oversized_block_descriptor_ptr();
344366
desc.next = m_oversized;
345367
desc.next_cached = oversized_block_descriptor_ptr();
368+
desc.current_size = bytes;
346369
*block = desc;
347370
m_oversized = block;
348371

349-
if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
372+
if (oversized_block_ptr_traits::get(desc.next))
350373
{
351374
oversized_block_descriptor next = *desc.next;
352375
next.prev = block;
@@ -439,7 +462,7 @@ class unsynchronized_pool_resource final
439462
assert(detail::is_power_of_2(alignment));
440463

441464
// verify that the pointer is at least as aligned as claimed
442-
assert(reinterpret_cast<detail::intmax_t>(detail::pointer_traits<void_ptr>::get(p)) % alignment == 0);
465+
assert(reinterpret_cast<detail::intmax_t>(void_ptr_traits::get(p)) % alignment == 0);
443466

444467
// the deallocated block is oversized and/or overaligned
445468
if (n > m_options.largest_block_size || alignment > m_options.alignment)
@@ -451,35 +474,44 @@ class unsynchronized_pool_resource final
451474
);
452475

453476
oversized_block_descriptor desc = *block;
477+
assert(desc.current_size == n);
478+
assert(desc.alignment == alignment);
454479

455480
if (m_options.cache_oversized)
456481
{
457482
desc.next_cached = m_cached_oversized;
458-
*block = desc;
483+
484+
if (desc.size != n) {
485+
desc.current_size = desc.size;
486+
block = static_cast<oversized_block_descriptor_ptr>(
487+
static_cast<void_ptr>(static_cast<char_ptr>(p) +
488+
desc.size));
489+
if (oversized_block_ptr_traits::get(desc.prev)) {
490+
thrust::raw_reference_cast(*desc.prev).next = block;
491+
} else {
492+
m_oversized = block;
493+
}
494+
495+
if (oversized_block_ptr_traits::get(desc.next)) {
496+
thrust::raw_reference_cast(*desc.next).prev = block;
497+
}
498+
}
499+
459500
m_cached_oversized = block;
501+
*block = desc;
460502

461503
return;
462504
}
463505

464-
if (!detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.prev))
465-
{
466-
assert(m_oversized == block);
506+
if (oversized_block_ptr_traits::get(
507+
desc.prev)) {
508+
thrust::raw_reference_cast(*desc.prev).next = desc.next;
509+
} else {
467510
m_oversized = desc.next;
468511
}
469-
else
470-
{
471-
oversized_block_descriptor prev = *desc.prev;
472-
assert(prev.next == block);
473-
prev.next = desc.next;
474-
*desc.prev = prev;
475-
}
476512

477-
if (detail::pointer_traits<oversized_block_descriptor_ptr>::get(desc.next))
478-
{
479-
oversized_block_descriptor next = *desc.next;
480-
assert(next.prev == block);
481-
next.prev = desc.prev;
482-
*desc.next = next;
513+
if (oversized_block_ptr_traits::get(desc.next)) {
514+
thrust::raw_reference_cast(*desc.next).prev = desc.prev;
483515
}
484516

485517
m_upstream->do_deallocate(p, desc.size + sizeof(oversized_block_descriptor), desc.alignment);

0 commit comments

Comments
 (0)