From 77e19f9cb9c5f858d30a05700094f3e7dd97aba5 Mon Sep 17 00:00:00 2001
From: "silu.zsl"
Date: Fri, 5 Sep 2025 15:46:48 +0800
Subject: [PATCH] feat: support reuse within queries

---
 rtp_llm/cpp/cache/BlockCache.cc               |  1 -
 rtp_llm/cpp/cache/BlockCache.h                | 16 ++++-
 rtp_llm/cpp/cache/CacheManager.cc             | 31 ++++++--
 rtp_llm/cpp/cache/CacheManager.h              | 39 ++++++++--
 rtp_llm/cpp/cache/test/CacheManagerTest.cc    | 71 +++++++++++++++++--
 .../engine_base/stream/StreamCacheResource.cc | 47 +++++++++++-
 .../engine_base/stream/StreamCacheResource.h  | 24 ++-----
 .../stream/test/StreamCacheResourceTest.cc    | 12 ++--
 rtp_llm/frontend/token_processor.py           | 10 ++-
 9 files changed, 206 insertions(+), 45 deletions(-)

diff --git a/rtp_llm/cpp/cache/BlockCache.cc b/rtp_llm/cpp/cache/BlockCache.cc
index 226fd518d..5634ebccd 100644
--- a/rtp_llm/cpp/cache/BlockCache.cc
+++ b/rtp_llm/cpp/cache/BlockCache.cc
@@ -61,7 +61,6 @@ std::vector<int> BlockCache::put(CacheItem& item) {
     if (item.token_list.empty() || item.block_indices.empty()) {
         return {};
     }
-    item.item_key = hashVector(item.token_list);
 
     std::lock_guard<std::mutex> lock(mutex_);
 
diff --git a/rtp_llm/cpp/cache/BlockCache.h b/rtp_llm/cpp/cache/BlockCache.h
index 64fe46339..8397117a6 100644
--- a/rtp_llm/cpp/cache/BlockCache.h
+++ b/rtp_llm/cpp/cache/BlockCache.h
@@ -6,7 +6,7 @@
 #include
 #include
 #include
-
+#include <sstream>
 #include "rtp_llm/cpp/utils/LRUCache.h"
 #include "rtp_llm/cpp/utils/AssertUtils.h"
 
@@ -19,6 +19,20 @@ struct CacheItem {
     std::vector<float> loss;
     bool               is_resident = false;
     size_t             item_key;
+    std::string debugString() const {
+        std::stringstream debug_string;
+        debug_string << "CacheItem item_key: " << item_key;
+        debug_string << " cache_keys: ";
+        for (auto& v : cache_key) {
+            debug_string << v << ", ";
+        }
+        debug_string << " block_indices: ";
+        for (auto& v : block_indices) {
+            debug_string << v << ", ";
+        }
+
+        return debug_string.str();
+    }
 };
 
 const size_t kCacheMaxCapacity = 1000000;
diff --git a/rtp_llm/cpp/cache/CacheManager.cc b/rtp_llm/cpp/cache/CacheManager.cc
index ec832b141..e459f3e2b 100644
--- a/rtp_llm/cpp/cache/CacheManager.cc
+++ b/rtp_llm/cpp/cache/CacheManager.cc
@@ -257,6 +257,10 @@ void CacheManager::incrBlockRefCounter(const std::vector<int>& indices) {
     allocator_->incrBlockRefCounter(indices);
 }
 
+void CacheManager::decrBlockRefCounter(const std::vector<int>& indices) {
+    allocator_->decrBlockRefCounter(indices);
+}
+
 void CacheManager::incrQueryRefCounter(const std::vector<int>& blocks) {
     std::set<int> unique_blocks(blocks.begin(), blocks.end());
     for (auto block : unique_blocks) {
@@ -265,13 +269,11 @@ void CacheManager::incrQueryRefCounter(const std::vector<int>& blocks) {
             available_blocks_--;
         }
     }
-
     query_ref_counter_.incrementRefCounter(blocks);
 }
 
 void CacheManager::decrQueryRefCounter(const std::vector<int>& blocks) {
     query_ref_counter_.decrementRefCounter(blocks);
-
     std::set<int> unique_blocks(blocks.begin(), blocks.end());
     for (auto block : unique_blocks) {
         if (query_ref_counter_.getRefCounter(block) == 0) {
@@ -349,15 +351,34 @@ void CacheManager::freeWithCache(FreeInfo& free_info) {
     std::lock_guard<std::mutex> guard(mutex_);
     decrQueryRefCounter(free_info.block_indices);
     free_info.is_resident = false;
-    insertIntoCache(free_info);
+    insertCacheThenFree(free_info);
+}
+
+void CacheManager::insertIntoCache(FreeInfo& free_info) {
+    if (free_info.token_ids.size() > 1) {
+        size_t token_len = free_info.token_ids.size() - 1;
+        size_t block_len = std::min(std::min(free_info.block_indices.size(), free_info.cache_keys.size()),
                                     token_len / seq_size_per_block_);
+        token_len        = block_len * seq_size_per_block_;
+        CacheItem item{{free_info.token_ids.begin(), free_info.token_ids.begin() + token_len},
+                       {free_info.block_indices.begin(), free_info.block_indices.begin() + block_len},
+                       {free_info.cache_keys.begin(), free_info.cache_keys.begin() + block_len},
+                       free_info.loss.empty() ?
+                           free_info.loss :
+                           std::vector<float>{free_info.loss.begin(), free_info.loss.begin() + token_len},
+                       free_info.is_resident};
+        incrBlockRefCounter(item.block_indices);
+        std::vector<int> indices = block_cache_.put(item);
+        allocator_->free(indices);
+    }
 }
 
 void CacheManager::insertResidentCache(FreeInfo& free_info) {
     free_info.is_resident = true;
-    insertIntoCache(free_info);
+    insertCacheThenFree(free_info);
 }
 
-void CacheManager::insertIntoCache(FreeInfo& free_info) {
+void CacheManager::insertCacheThenFree(FreeInfo& free_info) {
     if (free_info.token_ids.size() > 1) {
         size_t token_len = free_info.token_ids.size() - 1;
         size_t block_len = std::min(std::min(free_info.block_indices.size(), free_info.cache_keys.size()),
diff --git a/rtp_llm/cpp/cache/CacheManager.h b/rtp_llm/cpp/cache/CacheManager.h
index dc87c7e6f..000cf2fb5 100644
--- a/rtp_llm/cpp/cache/CacheManager.h
+++ b/rtp_llm/cpp/cache/CacheManager.h
@@ -17,6 +17,7 @@
 #include "kmonitor/client/MetricsReporter.h"
 #include "rtp_llm/cpp/cache/KvCacheInfo.h"
 #include "rtp_llm/cpp/cache/KVCacheAllocator.h"
+#include <sstream>
 
 namespace rtp_llm {
 
@@ -30,6 +31,17 @@ class CacheManager {
         size_t             remote_reuse_length = 0;
         std::vector<int>   cache_blocks;
         std::vector<float> loss;
+        std::string debugString() const {
+            std::stringstream debug_string;
+            debug_string << "MatchInfo reuse_length: " << reuse_length << ", local_reuse_length: " << local_reuse_length
+                         << ", remote_reuse_length: " << remote_reuse_length << ", cache_blocks: ";
+
+            for (auto& v : cache_blocks) {
+                debug_string << v << ", ";
+            }
+
+            return debug_string.str();
+        }
     };
 
     struct AdvancedMallocInfo {
@@ -84,6 +96,23 @@ class CacheManager {
         bool              is_resident = false;
         const std::string adapter_name;
         bool              enable_3fs = false;
+        std::string debugString() const {
+            std::stringstream debug_string;
+            debug_string << "FreeInfo request_id: " << request_id << ", token_ids size: " << token_ids.size();
+            debug_string << " cache_keys: ";
+            for (auto& v : cache_keys) {
+                debug_string << v << ", ";
+            }
+            debug_string << " block_indices: ";
+            for (auto& v : block_indices) {
+                debug_string << v << ", ";
+            }
+
+            return debug_string.str();
+        }
     };
 
 public:
@@ -112,6 +141,7 @@ class CacheManager {
     // returns the number of new available blocks if given blocks are freed
     size_t newFreeBlocks(const std::vector<int>& indice);
     void   insertResidentCache(FreeInfo& free_info);
+    void   insertIntoCache(FreeInfo& free_info);
 
     virtual void setKVBlockValue(int block_index, int layer_id, rtp_llm::Buffer& k_buffer, rtp_llm::Buffer& v_buffer);
     virtual void setKVBlockValue(int block_index, rtp_llm::Buffer& k_buffer, rtp_llm::Buffer& v_buffer);
@@ -151,11 +181,7 @@ class CacheManager {
     void maybeFreeBlockFromCache(int nums);
 
     void freeWithoutLock(const std::vector<int>& indice);
-    void insertIntoCache(FreeInfo& free_info);
-
-    void incrQueryRefCounter(const std::vector<int>& blocks);
-    void decrQueryRefCounter(const std::vector<int>& blocks);
-
+    void insertCacheThenFree(FreeInfo& free_info);
     void reportMetricsLoop();
 
     const std::shared_ptr<KVCacheAllocator>& kvCacheAllocator() const;
@@ -172,6 +198,9 @@ class CacheManager {
     std::string getLoraCkptPath(const std::string& adapter_name) const;
 
     void incrBlockRefCounter(const std::vector<int>& blocks);
+    void decrBlockRefCounter(const std::vector<int>& blocks);
+    void incrQueryRefCounter(const std::vector<int>& blocks);
+    void decrQueryRefCounter(const std::vector<int>& blocks);
 
 protected:
     CacheConfig config_;
diff --git a/rtp_llm/cpp/cache/test/CacheManagerTest.cc b/rtp_llm/cpp/cache/test/CacheManagerTest.cc
index 4e35078f9..372a0a6f6 100644
--- a/rtp_llm/cpp/cache/test/CacheManagerTest.cc
+++ b/rtp_llm/cpp/cache/test/CacheManagerTest.cc
@@ -1099,7 +1099,7 @@ TEST_F(CacheManagerTest, testInsertIntoCache_TokenLenLessThan1_FreeAllBlocks) {
     std::vector<int>       token_ids = {1000};  // size <= 1 triggers full free
     auto                   cache_keys = constructCacheKey(cache_manager, token_ids);
     CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
-    cache_manager.insertIntoCache(free_info);
+    cache_manager.insertCacheThenFree(free_info);
 
     ASSERT_EQ(cache_manager.freeBlockNums(), free0);
     ASSERT_FALSE(cache_manager.blockCache().hasKey({1000}));
@@ -1120,7 +1120,7 @@ TEST_F(CacheManagerTest, testInsertIntoCache_PutToBlockCache) {
     std::vector<int>       token_ids = {200, 201, 202};
     auto                   cache_keys = constructCacheKey(cache_manager, token_ids);
     CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
-    cache_manager.insertIntoCache(free_info);
+    cache_manager.insertCacheThenFree(free_info);
 
     // One tail block freed back
     ASSERT_EQ(cache_manager.freeBlockNums(), free_after_malloc + 1);
@@ -1145,19 +1145,82 @@ TEST_F(CacheManagerTest, testInsertIntoCache_PutToBlockCacheTwice) {
     std::vector<int>       token_ids = {200, 201, 202};
     auto                   cache_keys = constructCacheKey(cache_manager, token_ids);
     CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
-    cache_manager.insertIntoCache(free_info);
+    cache_manager.insertCacheThenFree(free_info);
     ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));
     ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 2);
     ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 2);
 
-    cache_manager.insertIntoCache(free_info);
+    cache_manager.insertCacheThenFree(free_info);
     ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));
     // ref count should be decremented
     ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 1);
     ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 1);
 }
 
+// Test insertIntoCache with valid input (token_ids.size() > 1)
+TEST_F(CacheManagerTest, testInsertIntoCache_ValidInput) {
+    auto cache_config               = initConfig();
+    cache_config.block_nums         = 10;
+    cache_config.seq_size_per_block = 1;
+    CacheManager cache_manager(cache_config, device_);
+    auto         allocator = cache_manager.kvCacheAllocator();
+
+    auto [ok, idx] = cache_manager.mallocIndex({request_id, 2});
+    ASSERT_TRUE(ok);
+
+    std::vector<int>       token_ids = {200, 201, 202};
+    auto                   cache_keys = constructCacheKey(cache_manager, token_ids);
+    CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
+
+    // Before insertIntoCache, cache should not have the key
+    ASSERT_FALSE(cache_manager.blockCache().hasKey({200, 201}));
+
+    // Call insertIntoCache
+    cache_manager.insertIntoCache(free_info);
+
+    // After insertIntoCache, cache should have the key
+    ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));
+
+    // Check ref counter is incremented
+    ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 2);
+    ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 2);
+}
+
+// Test insertIntoCache with token_ids.size() <= 1 (should not insert)
+TEST_F(CacheManagerTest, testInsertIntoCache_TokenLenLessThanEqualOne) {
+    auto cache_config               = initConfig();
+    cache_config.block_nums         = 10;
+    cache_config.seq_size_per_block = 1;
+    CacheManager cache_manager(cache_config, device_);
+
+    auto [ok, idx] = cache_manager.mallocIndex({request_id, 2});
+    ASSERT_TRUE(ok);
+
+    // Test with empty token_ids
+    std::vector<int>       token_ids_empty = {};
+    auto                   cache_keys_empty = constructCacheKey(cache_manager, token_ids_empty);
+    CacheManager::FreeInfo free_info_empty(request_id, token_ids_empty, cache_keys_empty, idx);
+
+    // Call insertIntoCache with empty token_ids
+    cache_manager.insertIntoCache(free_info_empty);
+
+    // Cache should remain empty
+    ASSERT_EQ(cache_manager.cacheItemNum(), 0u);
+
+    // Test with single token_id
+    std::vector<int>       token_ids_single = {100};
+    auto                   cache_keys_single = constructCacheKey(cache_manager, token_ids_single);
+    CacheManager::FreeInfo free_info_single(request_id, token_ids_single, cache_keys_single, idx);
+
+    // Call insertIntoCache with single token_id
+    cache_manager.insertIntoCache(free_info_single);
+
+    // Cache should remain empty
+    ASSERT_EQ(cache_manager.cacheItemNum(), 0u);
+}
+
 // loss present -> do NOT put to dist
 TEST_F(CacheManagerTest, testInsertIntoCache_LossNotEmpty_PutToBlockCache_NotPutToDistKvCache) {
     auto cache_config = initConfig();
diff --git a/rtp_llm/cpp/engine_base/stream/StreamCacheResource.cc b/rtp_llm/cpp/engine_base/stream/StreamCacheResource.cc
index 6a86d3bcd..128cc9d86 100644
--- a/rtp_llm/cpp/engine_base/stream/StreamCacheResource.cc
+++ b/rtp_llm/cpp/engine_base/stream/StreamCacheResource.cc
@@ -12,6 +12,29 @@ void StreamCacheResource::init(int batch_size) {
     constructCacheKey();
 }
 
+std::string StreamCacheResource::debugString() const {
+    std::stringstream debug_string;
+    debug_string << "StreamCacheResource { stream_id: " << stream_->streamId()
+                 << ", need_release_resource: " << need_release_resource_ << ", batch_resource: [";
+
+    for (size_t i = 0; i < batch_resource_.batchSize(); i++) {
+        debug_string << " [";
+        for (size_t j = 0; j < batch_resource_.batch_block_id[i].size(); j++) {
+            debug_string << batch_resource_.batch_block_id[i][j] << ", ";
+        }
+        debug_string << "],";
+    }
+    debug_string << "], cache_keys: [";
+    for (size_t i = 0; i < batch_resource_.batchSize(); i++) {
+        debug_string << " [";
+        for (size_t j = 0; j < batch_resource_.cache_keys[i].size(); j++) {
+            debug_string << batch_resource_.cache_keys[i][j] << ", ";
+        }
+        debug_string << "],";
+    }
+    debug_string << "] }";
+    return debug_string.str();
+}
+
 void StreamCacheResource::freeBatchBlocks(size_t batch_id, vector<int>& blocks) {
     if (blocks.empty()) {
         return;
@@ -135,7 +158,28 @@ absl::StatusOr StreamCacheResource::initKVBlock(int token_capacity, size_t
         }
     }
 
-    return incrKVBlock(token_capacity, reserve_step);
+    auto res = incrKVBlock(token_capacity, reserve_step);
+    // With within-query cache reuse enabled, the loss is not yet known when the cache is inserted, so it cannot be reused.
+    if (reuseCache() && !stream_->calculateLoss()) {
+        insertIntoCache();
+    }
+    return res;
+}
+
+void StreamCacheResource::insertIntoCache() {
+    for (size_t batch_id = 0; batch_id < batch_resource_.batchSize(); batch_id++) {
+        const auto& blocks     = batch_resource_.blocks(batch_id);
+        auto        tokens_id  = stream_->completeTokenIdsVec(batch_id);
+        auto        cache_keys = stream_->cacheKeys(batch_id);
+        if (!last_block_aligned_ && !cache_keys.empty()) {
+            cache_keys.pop_back();
+        }
+        vector<float>          loss;
+        CacheManager::FreeInfo free_info(stream_->streamId(), tokens_id, cache_keys, blocks, loss, adapter_name_);
+        resource_context_.cache_manager->insertIntoCache(free_info);
+    }
 }
 
 absl::StatusOr StreamCacheResource::incrKVBlock(int token_capacity, size_t reserve_step) {
@@ -367,6 +411,7 @@ void StreamCacheResource::reConstructCacheKeys() {
     }
     auto seq_size_per_block = seqSizePerBlock();
     auto total_blocks       = stream_->seqLength() / seq_size_per_block;
+
     for (size_t i = 0; i < stream_->currentBatchSize(); ++i) {
         if (!last_block_aligned_ && !batch_resource_.cache_keys[i].empty()) {
             batch_resource_.cache_keys[i].pop_back();
diff --git a/rtp_llm/cpp/engine_base/stream/StreamCacheResource.h b/rtp_llm/cpp/engine_base/stream/StreamCacheResource.h
index ac77d60e6..1fccc040e 100644
--- a/rtp_llm/cpp/engine_base/stream/StreamCacheResource.h
+++ b/rtp_llm/cpp/engine_base/stream/StreamCacheResource.h
@@ -8,7 +8,6 @@
 #include
 
 namespace rtp_llm {
-
 class GenerateStream;
 
 struct ResourceContext {
@@ -42,6 +41,7 @@ class StreamCacheResource {
     const std::vector<int64_t>& cacheKeys(int32_t batch_id) const;
     absl::StatusOr initKVBlock(int token_capacity, size_t reserve_step = 0);
     absl::StatusOr incrKVBlock(int token_capacity, size_t reserve_step = 0);
+    void         insertIntoCache();
     void         fakeInitKVBlock();
     int          tryReleaseKVBlock(size_t nums);
     absl::Status releaseSequenceKVCache(size_t total_seq_len, size_t release_seq_len);
@@ -103,25 +103,9 @@ class StreamCacheResource {
         stream_ = stream;
     }
 
-    bool reuseCache() const;
-    bool enable3FS() const;
-
-    std::string debugString() const {
-        std::stringstream debug_string;
-        debug_string << "StreamCacheResource {"
-                     << "need_release_resource: " << need_release_resource_ << ", batch_resource: [";
-
-        for (size_t i = 0; i < batch_resource_.batchSize(); i++) {
-            debug_string << " [";
-            for (size_t j = 0; j < batch_resource_.batch_block_id[i].size(); j++) {
-                debug_string << batch_resource_.batch_block_id[i][j] << " ";
-            }
-            debug_string << "],";
-        }
-
-        debug_string << "}";
-        return debug_string.str();
-    }
+    bool        reuseCache() const;
+    bool        enable3FS() const;
+    std::string debugString() const;
 
 private:
     GenerateStream* stream_;
diff --git a/rtp_llm/cpp/engine_base/stream/test/StreamCacheResourceTest.cc b/rtp_llm/cpp/engine_base/stream/test/StreamCacheResourceTest.cc
index b00802066..4986d1a27 100644
--- a/rtp_llm/cpp/engine_base/stream/test/StreamCacheResourceTest.cc
+++ b/rtp_llm/cpp/engine_base/stream/test/StreamCacheResourceTest.cc
@@ -311,11 +311,12 @@ TEST_F(StreamCacheResourceTest, testReuseCache) {
     stream_->releaseResource();
 
     ASSERT_EQ(cache_manager_->freeBlockNums(), 3);
-    ASSERT_EQ(cache_manager_->cacheItemNum(), 2);
+    ASSERT_EQ(cache_manager_->cacheItemNum(), 3);
+    ASSERT_TRUE(cache_manager_->blockCache().hasKey({1, 2, 3, 4}));
     ASSERT_TRUE(cache_manager_->blockCache().hasKey({1, 2, 3, 4, 5, 6, 7, 8}));
     ASSERT_TRUE(cache_manager_->blockCache().hasKey({1, 2, 3, 4, 5, 6, 9, 10}));
-    ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(1), 2);
-    ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(2), 2);
+    ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(1), 3);
+    ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(2), 3);
     ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(3), 2);
     ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(4), 1);
     ASSERT_EQ(allocator_->blockRefCounter().getRefCounter(5), 0);
@@ -380,7 +381,8 @@ TEST_F(StreamCacheResourceTest, testReuseCacheWithFastGen) {
     stream_->releaseResource();
 
     ASSERT_EQ(cache_manager_->freeBlockNums(), 5);
-    ASSERT_EQ(cache_manager_->cacheItemNum(), 1);
+    ASSERT_EQ(cache_manager_->cacheItemNum(), 2);
+    ASSERT_TRUE(cache_manager_->blockCache().hasKey({1, 2, 3, 4}));
     ASSERT_TRUE(cache_manager_->blockCache().hasKey({1, 2, 3, 4, 5, 6}));
 
     // test another stream
@@ -444,7 +446,7 @@ TEST_F(StreamCacheResourceTest, testReuseCacheWithFastGen) {
     stream_->setPaused();
     ASSERT_EQ(cache_manager_->freeBlockNums(), 3);
     ASSERT_EQ(cache_manager_->availableBlockNums(), 8);
-    ASSERT_EQ(cache_manager_->cacheItemNum(), 2);
+    ASSERT_EQ(cache_manager_->cacheItemNum(), 4);
     ASSERT_EQ(stream_->maxBlockSize(), 0);
 
     // first chunk again
diff --git a/rtp_llm/frontend/token_processor.py b/rtp_llm/frontend/token_processor.py
index 930535aed..a19ee8f73 100644
--- a/rtp_llm/frontend/token_processor.py
+++ b/rtp_llm/frontend/token_processor.py
@@ -71,9 +71,13 @@ def decode_tokens(
     return_incremental: bool = False,
 ):
     if not self.has_num_beams:
-        self.ouput_tokens_list[i] = np.concatenate(
-            (self.ouput_tokens_list[i], tokens), axis=1
-        )
+        # Handle the case where self.ouput_tokens_list[i] is empty
+        if self.ouput_tokens_list[i].size == 0:
+            self.ouput_tokens_list[i] = tokens
+        else:
+            self.ouput_tokens_list[i] = np.concatenate(
+                (self.ouput_tokens_list[i], tokens), axis=1
+            )
     tokens = self.ouput_tokens_list[i]
     tokens = remove_padding_eos_with_numpy(
         tokens, self.special_tokens.eos_token_id
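
Reviewer note, not part of the patch: a minimal sketch of the intended call pattern, pieced together from the unit tests above. It shows the split this change introduces between the now-public insertIntoCache() (mid-query publish, blocks stay held) and the renamed insertCacheThenFree() reached via freeWithCache() (publish on release). The token/cache-key values are illustrative, the int64_t cache-key element type is an assumption, and constructCacheKey() in the tests is a fixture helper rather than a public API.

    // Sketch only: assumes a CacheManager already constructed as in the tests,
    // i.e. CacheManager cache_manager(cache_config, device_).
    #include "rtp_llm/cpp/cache/CacheManager.h"

    using namespace rtp_llm;

    void withinQueryReuseSketch(CacheManager& cache_manager, int64_t request_id) {
        // The running query holds two blocks for its prefix.
        auto [ok, idx] = cache_manager.mallocIndex({request_id, 2});
        if (!ok) {
            return;
        }

        std::vector<int>     token_ids  = {200, 201, 202};   // illustrative
        std::vector<int64_t> cache_keys = {/* one key per full block */};
        CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);

        // New mid-query path (called from initKVBlock when reuseCache() is on
        // and loss is not computed): publish the prefix blocks to the block
        // cache while the query is still running. Ref counts go up and only
        // the indices returned by put() are freed, so the query keeps its
        // blocks and repeated calls do not leak.
        cache_manager.insertIntoCache(free_info);

        // Release path (this is what insertIntoCache used to do): insert the
        // final prefix into the cache, then free the query's references.
        cache_manager.freeWithCache(free_info);
    }

The rename makes the distinction explicit: insertCacheThenFree keeps the old insert-then-release semantics for end-of-stream, while the new insertIntoCache only bumps block/query ref counters, which is what lets StreamCacheResource::insertIntoCache run on every initKVBlock for within-query prefix reuse.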