Skip to content

Commit 275c9ee

Browse files
optimize bf allocator
1 parent 740b07c commit 275c9ee

File tree

1 file changed

+68
-50
lines changed

1 file changed

+68
-50
lines changed

dipu/torch_dipu/csrc_dipu/runtime/core/allocator/DIPUBFCachingAllocator.cpp

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,15 @@
88
#include <utility>
99
#include <vector>
1010

11-
#include "csrc_dipu/utils/env.hpp"
12-
1311
#include "DIPUCachingAllocator.h"
1412
#include "DIPUSpinMutex.h"
1513

1614
namespace dipu {
1715

1816
// Rounds `nbytes` up to the next multiple of `alignment_size`.
//
// The bit trick is only a true round-up when `alignment_size` is a power of
// two. For `nbytes == 0` the unsigned subtraction wraps around and the
// result is 0.
inline size_t round_up_to_alignment(size_t nbytes, size_t alignment_size) {
  const size_t low_bits_mask = alignment_size - 1;
  const size_t last_byte = nbytes - 1;
  // Saturate every bit below the alignment, then step onto the next boundary.
  return (last_byte | low_bits_mask) + 1;
}
19+
2420
class BFCachingAllocatorImpl {
2521
public:
2622
using allocate_fn_t = std::function<void*(size_t)>;
@@ -34,12 +30,23 @@ class BFCachingAllocatorImpl {
3430
// Number of second level bins (linearly)
3531
static constexpr int kNumSubBins = 4;
3632
static constexpr int kLogNumSubBins = 2;
33+
3734
// Allocation parameters
38-
static constexpr int kMinAllocationSize = 512;
39-
static constexpr int kSmallBlockSize = 2 << 20;
40-
static constexpr int kMiddleBlockSize = 20 << 20;
41-
static constexpr int kLargeBlockSize = 200 << 20;
42-
static constexpr int kLargeAlignSize = 1024 << 20;
35+
static constexpr size_t kMinBlockSize =
36+
512; // all sizes are rounded to at least 512 bytes
37+
static constexpr size_t kSmallSize =
38+
1048576; // largest "small" allocation is 1 MiB
39+
static constexpr size_t kSmallBuffer =
40+
2097152; // "small" allocations are packed in 2 MiB blocks
41+
static constexpr size_t kLargeBuffer =
42+
20971520; // "large" allocations may be packed in 20 MiB blocks
43+
static constexpr size_t kMinLargeAlloc =
44+
10485760; // allocations between 1 and 10 MiB may use kLargeBuffer
45+
static constexpr size_t kRoundLarge =
46+
2097152; // round up large allocations to 2 MiB
47+
static constexpr size_t kMaxSplitableBlockSize =
48+
200 << 20; // To further reduce fragmentation, blocks >= 200MB are not
49+
// allowed to be split
4350

4451
size_t cachedBytes = 0;
4552
size_t allocatedBytes = 0;
@@ -143,10 +150,11 @@ class BFCachingAllocatorImpl {
143150
mutable mutex_t mut_;
144151

145152
static size_t roundBytes(size_t nbytes) {
146-
if (nbytes < kLargeBlockSize) {
147-
return round_up_to_alignment(nbytes, kMinAllocationSize);
153+
if (nbytes <= kMinBlockSize) {
154+
return kMinBlockSize;
148155
}
149-
return round_up_to_alignment(nbytes, kSmallBlockSize);
156+
int clz = __builtin_clzll(nbytes - 1);
157+
return (1 << (sizeof(int64_t) - clz));
150158
}
151159

152160
int newChunk(void* ptr, size_t size, size_t stream) {
@@ -169,7 +177,7 @@ class BFCachingAllocatorImpl {
169177
// Big bin range:
170178
// [2^`bigBinIdx`, 2^(`bigBinIdx`+1)), length: 2^`bigBinIdx`
171179
// Split big bin into `kNumSubBins` sub bins
172-
size_t nBlocks = nbytes / kMinAllocationSize;
180+
size_t nBlocks = nbytes / kMinBlockSize;
173181
constexpr int kMaxBinIdx = 63;
174182
int bigBinIdx = kMaxBinIdx - __builtin_clzll(nBlocks);
175183
// If `nbytes` is so large, we just put it into the last
@@ -245,16 +253,22 @@ class BFCachingAllocatorImpl {
245253
return id;
246254
}
247255

248-
void shrink(StreamSetHandle& set) {
256+
void shrink(StreamSetHandle& set, size_t try_release_size = 0) {
257+
size_t released_size = 0;
249258
for (int binHead : set->binHeads_) {
250259
int k = chunks_[binHead].nextChunkInList;
251260
while (k) {
252-
if (chunks_[k].isMonoBlock()) {
253-
releaseOnDevice(chunks_[k].ptr, chunks_[k].size);
261+
auto& chunk_k = chunks_[k];
262+
if (chunk_k.isMonoBlock()) {
263+
released_size += chunk_k.size;
264+
releaseOnDevice(chunk_k.ptr, chunk_k.size);
254265
removeChunkFromBin(k);
255266
recycleIds_.push(k);
267+
if (try_release_size > 0 && released_size >= try_release_size) {
268+
break;
269+
}
256270
}
257-
k = chunks_[k].nextChunkInList;
271+
k = chunk_k.nextChunkInList;
258272
}
259273
}
260274
}
@@ -297,33 +311,39 @@ class BFCachingAllocatorImpl {
297311
return id;
298312
}
299313

300-
int extend(size_t nbytes, StreamSetHandle& set) {
301-
emptyCacheWithoutLock();
302-
bool increased = false;
303-
size_t allocateSize = nbytes;
304-
if (nbytes < kSmallBlockSize) {
305-
allocateSize = kSmallBlockSize;
306-
} else if (nbytes < kMiddleBlockSize) {
307-
allocateSize = kMiddleBlockSize;
308-
} else if (nbytes < kLargeBlockSize) {
309-
allocateSize = round_up_to_alignment(nbytes, kMiddleBlockSize);
310-
} else {
311-
allocateSize = round_up_to_alignment(nbytes, kLargeAlignSize);
314+
size_t getAllocateSize(size_t nbytes) {
315+
if (nbytes <= kSmallSize) {
316+
return kSmallBuffer;
312317
}
318+
if (nbytes < kMinLargeAlloc) {
319+
return kLargeBuffer;
320+
}
321+
return round_up_to_alignment(nbytes, kRoundLarge);
322+
}
313323

314-
size_t currBytes = std::max(nbytes, allocateSize);
315-
void* ptr = allocateOnDevice(currBytes);
324+
int extend(size_t nbytes, StreamSetHandle& set) {
325+
size_t allocateSize = getAllocateSize(nbytes);
326+
327+
void* ptr = allocateOnDevice(allocateSize);
328+
if (!ptr) {
329+
shrink(set, allocateSize);
330+
ptr = allocateOnDevice(allocateSize);
331+
}
332+
if (!ptr) {
333+
shrink(set);
334+
ptr = allocateOnDevice(allocateSize);
335+
}
316336
if (!ptr) {
317-
if (currBytes > nbytes) {
318-
currBytes = nbytes;
319-
ptr = allocateOnDevice(currBytes);
337+
if (allocateSize > nbytes) {
338+
allocateSize = nbytes;
339+
ptr = allocateOnDevice(allocateSize);
320340
}
321341
}
322342
if (!ptr) {
323343
return 0;
324344
}
325345

326-
int id = newChunk(ptr, currBytes, set->id);
346+
int id = newChunk(ptr, allocateSize, set->id);
327347
return id;
328348
}
329349

@@ -378,17 +398,7 @@ class BFCachingAllocatorImpl {
378398
}
379399

380400
if (id) {
381-
int internlalMaxFragnmentSize = 0;
382-
const int chunk_size = static_cast<int>(chunks_[id].size);
383-
if (chunk_size < kSmallBlockSize) {
384-
internlalMaxFragnmentSize = kMinAllocationSize;
385-
} else if (chunk_size < kLargeAlignSize) {
386-
internlalMaxFragnmentSize = kSmallBlockSize;
387-
} else {
388-
internlalMaxFragnmentSize = kLargeAlignSize;
389-
}
390-
if ((chunk_size >= (nbytes << 1)) ||
391-
(chunk_size > (nbytes + internlalMaxFragnmentSize))) {
401+
if (chunks_[id].size >= (nbytes << 1)) {
392402
id = split(id, nbytes);
393403
}
394404
chunks_[id].allocated = true;
@@ -522,6 +532,9 @@ class BFCachingAllocator : public CacheAllocator {
522532
: DataPtrContextBase(allocator, ptr, size), id_(id), nbytes_(nbytes) {}
523533

524534
~Context() {
535+
if (size() <= 0) {
536+
return;
537+
}
525538
auto allocator_ = static_cast<const BFCachingAllocator*>(allocator());
526539
DIPU_DEBUG_ALLOCATOR(8, "BFCachingAllocator: add to async_mem_pool:"
527540
<< ptr() << ", " << size() << " nbytes, id:"
@@ -531,16 +544,21 @@ class BFCachingAllocator : public CacheAllocator {
531544
if (ptr()) {
532545
allocator_->metrics_producer.deallocate(ptr());
533546
std::deque<DIPUEvent> events;
547+
bool record_block = false;
534548
for (auto const& stream : streams()) {
535549
events.emplace_back();
536550
DIPU_DEBUG_ALLOCATOR(8, "BFCachingAllocator: record to stream:"
537551
<< stream.rawstream());
538552
events.back().record(stream);
553+
record_block = true;
539554
}
540555
allocator_->async_mem_pool()->add(std::make_tuple(ptr(), id_),
541556
events);
542557
allocator_->set_memory_allocated(allocator_->memory_allocated() -
543558
nbytes_);
559+
if (!record_block) {
560+
allocator_->restore();
561+
}
544562
}
545563
} else {
546564
DIPU_DEBUG_ALLOCATOR(8,
@@ -552,12 +570,12 @@ class BFCachingAllocator : public CacheAllocator {
552570

553571
friend class Context;
554572

555-
c10::DataPtr allocate(size_t size) const override {
573+
c10::DataPtr allocate(size_t origin_size) const override {
556574
restore();
557575
if (async_mem_pool()->size() > kMaxAsyncResourcePoolLength) {
558576
try_empty_resource_pool();
559577
}
560-
size = getMemoryAlignmentStrategy()->roundBytes(size);
578+
size_t size = getMemoryAlignmentStrategy()->roundBytes(origin_size);
561579
std::tuple<void*, int, size_t> block = impl->allocateRaw(size);
562580
void* ptr = std::get<0>(block);
563581
if (ptr == nullptr && size > 0) {
@@ -583,7 +601,7 @@ class BFCachingAllocator : public CacheAllocator {
583601
deleteBFContext, device());
584602
DIPU_DEBUG_ALLOCATOR(
585603
4, "BFCachingAllocator: malloc "
586-
<< nbytes << ",requires " << size << " nbytes, ptr:" << ptr
604+
<< nbytes << ",requires " << origin_size << " nbytes, ptr:" << ptr
587605
<< ",device:" << device()
588606
<< ",async_mempool.size:" << async_mem_pool()->size());
589607
c10::reportMemoryUsageToProfiler(

0 commit comments

Comments
 (0)