Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rtp_llm/cpp/cache/BlockCache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ std::vector<int> BlockCache::put(CacheItem& item) {
if (item.token_list.empty() || item.block_indices.empty()) {
return {};
}

item.item_key = hashVector(item.token_list);

std::lock_guard<std::mutex> lock(mutex_);
Expand Down
16 changes: 15 additions & 1 deletion rtp_llm/cpp/cache/BlockCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <vector>
#include <mutex>
#include <unordered_map>

#include <sstream>
#include "rtp_llm/cpp/utils/LRUCache.h"
#include "rtp_llm/cpp/utils/AssertUtils.h"

Expand All @@ -19,6 +19,20 @@ struct CacheItem {
std::vector<float> loss;
bool is_resident = false;
size_t item_key;
std::string debugString() const {
std::stringstream debug_string;
debug_string << "FreeInfo item_key: " << item_key << ", token_ids: ";
debug_string << " cache_keys: ";
for (auto& v : cache_key) {
debug_string << v << ", ";
}
debug_string << " block_indices: ";
for (auto& v : block_indices) {
debug_string << v << ", ";
}

return debug_string.str();
}
};

const size_t kCacheMaxCapacity = 1000000;
Expand Down
31 changes: 26 additions & 5 deletions rtp_llm/cpp/cache/CacheManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ void CacheManager::incrBlockRefCounter(const std::vector<int>& indices) {
allocator_->incrBlockRefCounter(indices);
}

// Decrement the block-level reference counter for each given block index.
// Thin delegator to the underlying KV-cache allocator; counterpart of
// incrBlockRefCounter above.
void CacheManager::decrBlockRefCounter(const std::vector<int>& indices) {
    allocator_->decrBlockRefCounter(indices);
}

void CacheManager::incrQueryRefCounter(const std::vector<int>& blocks) {
std::set<int> unique_blocks(blocks.begin(), blocks.end());
for (auto block : unique_blocks) {
Expand All @@ -265,13 +269,11 @@ void CacheManager::incrQueryRefCounter(const std::vector<int>& blocks) {
available_blocks_--;
}
}

query_ref_counter_.incrementRefCounter(blocks);
}

void CacheManager::decrQueryRefCounter(const std::vector<int>& blocks) {
query_ref_counter_.decrementRefCounter(blocks);

std::set<int> unique_blocks(blocks.begin(), blocks.end());
for (auto block : unique_blocks) {
if (query_ref_counter_.getRefCounter(block) == 0) {
Expand Down Expand Up @@ -349,15 +351,34 @@ void CacheManager::freeWithCache(FreeInfo& free_info) {
std::lock_guard<std::mutex> guard(mutex_);
decrQueryRefCounter(free_info.block_indices);
free_info.is_resident = false;
insertIntoCache(free_info);
insertCacheThenFree(free_info);
}

// Insert the request's filled KV blocks into the reusable block cache WITHOUT
// releasing the request's own hold on them (unlike insertCacheThenFree, this
// is called while the stream is still running and still owns its blocks).
//
// The trailing token is excluded (its KV value is not complete yet), and the
// insert length is clamped down to a whole number of blocks covered by
// block_indices, cache_keys and the remaining token count.
void CacheManager::insertIntoCache(FreeInfo& free_info) {
    if (free_info.token_ids.size() > 1) {
        // Drop the last token, then round down to whole blocks.
        size_t token_len = free_info.token_ids.size() - 1;
        size_t block_len = std::min(std::min(free_info.block_indices.size(), free_info.cache_keys.size()),
                                    token_len / seq_size_per_block_);
        token_len        = block_len * seq_size_per_block_;
        CacheItem item{{free_info.token_ids.begin(), free_info.token_ids.begin() + token_len},
                       {free_info.block_indices.begin(), free_info.block_indices.begin() + block_len},
                       {free_info.cache_keys.begin(), free_info.cache_keys.begin() + block_len},
                       free_info.loss.empty() ?
                           free_info.loss :
                           std::vector<float>{free_info.loss.begin(), free_info.loss.begin() + token_len},
                       free_info.is_resident};
        // Take a cache-side reference on the inserted blocks...
        incrBlockRefCounter(item.block_indices);
        // ...then free whatever put() hands back (evicted blocks, or -- when
        // the key already existed -- the just-incremented new indices, which
        // balances the increment above).
        // NOTE(review): a reviewer questioned whether this free() is correct
        // here ("这里 free 是不对的?") -- confirm that freeing put()'s return
        // value cannot release blocks the running stream still needs.
        std::vector<int> indices = block_cache_.put(item);
        allocator_->free(indices);
    }
}

// Insert the request's blocks into the cache as a resident item (resident
// items are never evicted), then release the request's hold on them via
// insertCacheThenFree.
void CacheManager::insertResidentCache(FreeInfo& free_info) {
    free_info.is_resident = true;
    insertCacheThenFree(free_info);
}

void CacheManager::insertIntoCache(FreeInfo& free_info) {
void CacheManager::insertCacheThenFree(FreeInfo& free_info) {
if (free_info.token_ids.size() > 1) {
size_t token_len = free_info.token_ids.size() - 1;
size_t block_len = std::min(std::min(free_info.block_indices.size(), free_info.cache_keys.size()),
Expand Down
39 changes: 34 additions & 5 deletions rtp_llm/cpp/cache/CacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "kmonitor/client/MetricsReporter.h"
#include "rtp_llm/cpp/cache/KvCacheInfo.h"
#include "rtp_llm/cpp/cache/KVCacheAllocator.h"
#include <sstream>

namespace rtp_llm {

Expand All @@ -30,6 +31,17 @@ class CacheManager {
size_t remote_reuse_length = 0;
std::vector<int> cache_blocks;
std::vector<float> loss;
// Render this MatchInfo as a single-line string for debug logging:
// the three reuse lengths followed by the matched cache block ids.
std::string debugString() const {
    std::stringstream ss;
    ss << "MatchInfo reuse_length: " << reuse_length << ", local_reuse_length: " << local_reuse_length
       << ",remote_reuse_length: " << remote_reuse_length << ";cache_blocks: ";

    for (size_t i = 0; i < cache_blocks.size(); ++i) {
        ss << cache_blocks[i] << ", ";
    }

    return ss.str();
}
};

struct AdvancedMallocInfo {
Expand Down Expand Up @@ -84,6 +96,23 @@ class CacheManager {
bool is_resident = false;
const std::string adapter_name;
bool enable_3fs = false;
std::string debugString() const {
std::stringstream debug_string;
debug_string << "FreeInfo request_id: " << request_id << ", token_ids: ";
/*for (auto& v : token_ids) {
debug_string << v << ", ";
}*/
debug_string << " cache_keys: ";
for (auto& v : cache_keys) {
debug_string << v << ", ";
}
debug_string << " block_indices: ";
for (auto& v : block_indices) {
debug_string << v << ", ";
}

return debug_string.str();
}
};

public:
Expand Down Expand Up @@ -112,6 +141,7 @@ class CacheManager {
// returns the number of new available blocks if given blocks are freed
size_t newFreeBlocks(const std::vector<int>& indice);
void insertResidentCache(FreeInfo& free_info);
void insertIntoCache(FreeInfo& free_info);

virtual void setKVBlockValue(int block_index, int layer_id, rtp_llm::Buffer& k_buffer, rtp_llm::Buffer& v_buffer);
virtual void setKVBlockValue(int block_index, rtp_llm::Buffer& k_buffer, rtp_llm::Buffer& v_buffer);
Expand Down Expand Up @@ -151,11 +181,7 @@ class CacheManager {
void maybeFreeBlockFromCache(int nums);

void freeWithoutLock(const std::vector<int>& indice);
void insertIntoCache(FreeInfo& free_info);

void incrQueryRefCounter(const std::vector<int>& blocks);
void decrQueryRefCounter(const std::vector<int>& blocks);

void insertCacheThenFree(FreeInfo& free_info);
void reportMetricsLoop();

const std::shared_ptr<KVCacheAllocator>& kvCacheAllocator() const;
Expand All @@ -172,6 +198,9 @@ class CacheManager {
std::string getLoraCkptPath(const std::string& adapter_name) const;

void incrBlockRefCounter(const std::vector<int>& blocks);
void decrBlockRefCounter(const std::vector<int>& blocks);
void incrQueryRefCounter(const std::vector<int>& blocks);
void decrQueryRefCounter(const std::vector<int>& blocks);

protected:
CacheConfig config_;
Expand Down
71 changes: 67 additions & 4 deletions rtp_llm/cpp/cache/test/CacheManagerTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ TEST_F(CacheManagerTest, testInsertIntoCache_TokenLenLessThan1_FreeAllBlocks) {
std::vector<int> token_ids = {1000}; // size <= 1 triggers full free
auto cache_keys = constructCacheKey(cache_manager, token_ids);
CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
cache_manager.insertIntoCache(free_info);
cache_manager.insertCacheThenFree(free_info);

ASSERT_EQ(cache_manager.freeBlockNums(), free0);
ASSERT_FALSE(cache_manager.blockCache().hasKey({1000}));
Expand All @@ -1120,7 +1120,7 @@ TEST_F(CacheManagerTest, testInsertIntoCache_PutToBlockCache) {
std::vector<int> token_ids = {200, 201, 202};
auto cache_keys = constructCacheKey(cache_manager, token_ids);
CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
cache_manager.insertIntoCache(free_info);
cache_manager.insertCacheThenFree(free_info);

// One tail block freed back
ASSERT_EQ(cache_manager.freeBlockNums(), free_after_malloc + 1);
Expand All @@ -1145,19 +1145,82 @@ TEST_F(CacheManagerTest, testInsertIntoCache_PutToBlockCacheTwice) {
std::vector<int> token_ids = {200, 201, 202};
auto cache_keys = constructCacheKey(cache_manager, token_ids);
CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);
cache_manager.insertIntoCache(free_info);
cache_manager.insertCacheThenFree(free_info);

ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));
ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 2);
ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 2);

cache_manager.insertIntoCache(free_info);
cache_manager.insertCacheThenFree(free_info);
ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));
// ref count should be decremented
ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 1);
ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 1);
}

// Test insertIntoCache with valid input (token_ids.size() > 1):
// the {200, 201} prefix must be inserted into the block cache and BOTH
// backing blocks must gain a cache-side reference.
TEST_F(CacheManagerTest, testInsertIntoCache_ValidInput) {
    auto cache_config               = initConfig();
    cache_config.block_nums         = 10;
    cache_config.seq_size_per_block = 1;
    CacheManager cache_manager(cache_config, device_);
    auto allocator = cache_manager.kvCacheAllocator();

    auto [ok, idx] = cache_manager.mallocIndex({request_id, 2});
    ASSERT_TRUE(ok);

    std::vector<int> token_ids = {200, 201, 202};
    auto cache_keys = constructCacheKey(cache_manager, token_ids);
    CacheManager::FreeInfo free_info(request_id, token_ids, cache_keys, idx);

    // Before insertIntoCache, cache should not have the key
    ASSERT_FALSE(cache_manager.blockCache().hasKey({200, 201}));

    // Call insertIntoCache
    cache_manager.insertIntoCache(free_info);

    // After insertIntoCache, cache should have the key
    ASSERT_TRUE(cache_manager.blockCache().hasKey({200, 201}));

    // Check ref counter is incremented on both blocks.
    // Bug fix: the second assertion previously re-checked idx[0] twice,
    // leaving idx[1] unverified.
    ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[0]) == 2);
    ASSERT_TRUE(allocator->blockRefCounter().getRefCounter(idx[1]) == 2);
}

// insertIntoCache must be a no-op when token_ids has at most one element
// (there is never a complete block to publish).
TEST_F(CacheManagerTest, testInsertIntoCache_TokenLenLessThanEqualOne) {
    auto config               = initConfig();
    config.block_nums         = 10;
    config.seq_size_per_block = 1;
    CacheManager cache_manager(config, device_);

    auto [alloc_ok, block_ids] = cache_manager.mallocIndex({request_id, 2});
    ASSERT_TRUE(alloc_ok);

    // Case 1: an empty token list -> nothing is inserted.
    std::vector<int> no_tokens = {};
    auto no_token_keys = constructCacheKey(cache_manager, no_tokens);
    CacheManager::FreeInfo empty_info(request_id, no_tokens, no_token_keys, block_ids);
    cache_manager.insertIntoCache(empty_info);
    ASSERT_EQ(cache_manager.cacheItemNum(), 0u);

    // Case 2: a single token -> still nothing is inserted.
    std::vector<int> one_token = {100};
    auto one_token_keys = constructCacheKey(cache_manager, one_token);
    CacheManager::FreeInfo single_info(request_id, one_token, one_token_keys, block_ids);
    cache_manager.insertIntoCache(single_info);
    ASSERT_EQ(cache_manager.cacheItemNum(), 0u);
}

// loss present -> do NOT put to dist
TEST_F(CacheManagerTest, testInsertIntoCache_LossNotEmpty_PutToBlockCache_NotPutToDistKvCache) {
auto cache_config = initConfig();
Expand Down
47 changes: 46 additions & 1 deletion rtp_llm/cpp/engine_base/stream/StreamCacheResource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@ void StreamCacheResource::init(int batch_size) {
constructCacheKey();
}

// Single-line debug dump of this stream's cache resources: the per-batch
// block id lists followed by the per-batch cache key lists.
std::string StreamCacheResource::debugString() const {
    std::stringstream out;
    out << "StreamCacheResource { stream_id: " << stream_->streamId()
        << ",need_release_resource: " << need_release_resource_ << ", batch_resource: [";

    // Both sections print one bracketed list per batch; share the loop.
    auto append_lists = [&](const auto& lists) {
        for (size_t i = 0; i < batch_resource_.batchSize(); i++) {
            out << " [";
            for (const auto& elem : lists[i]) {
                out << elem << ", ";
            }
            out << "],";
        }
    };

    append_lists(batch_resource_.batch_block_id);
    out << ", cache_keys: ";
    append_lists(batch_resource_.cache_keys);
    out << "}";
    return out.str();
}
void StreamCacheResource::freeBatchBlocks(size_t batch_id, vector<int>& blocks) {
if (blocks.empty()) {
return;
Expand Down Expand Up @@ -135,7 +158,28 @@ absl::StatusOr<int> StreamCacheResource::initKVBlock(int token_capacity, size_t
}
}

return incrKVBlock(token_capacity, reserve_step);
auto res = incrKVBlock(token_capacity, reserve_step);
if (reuseCache()
&& !stream_->calculateLoss()) { // 如果开启query内部reuse cache,在insert cache的时候还不知道loss,无法复用
insertIntoCache();
}
return res;
}

// Publish every batch's currently-held blocks into the shared block cache so
// other queries can reuse them while this stream is still running. Loss is
// not known yet at this point, so an empty loss vector is passed (see the
// caller: in-query reuse is disabled when the stream computes loss).
void StreamCacheResource::insertIntoCache() {
    for (size_t batch_id = 0; batch_id < batch_resource_.batchSize(); batch_id++) {
        const auto& blocks     = batch_resource_.blocks(batch_id);
        auto        tokens_id  = stream_->completeTokenIdsVec(batch_id);
        auto        cache_keys = stream_->cacheKeys(batch_id);
        // Drop the trailing cache key when the last block is not aligned --
        // presumably it covers a partially filled block; see the same pattern
        // in reConstructCacheKeys. TODO(review): confirm.
        if (!last_block_aligned_ && !cache_keys.empty()) {
            cache_keys.pop_back();
        }
        std::vector<float> loss;  // intentionally empty: loss unknown at insert time
        CacheManager::FreeInfo free_info(stream_->streamId(), tokens_id, cache_keys, blocks, loss, adapter_name_);
        resource_context_.cache_manager->insertIntoCache(free_info);
    }
}

absl::StatusOr<int> StreamCacheResource::incrKVBlock(int token_capacity, size_t reserve_step) {
Expand Down Expand Up @@ -367,6 +411,7 @@ void StreamCacheResource::reConstructCacheKeys() {
}
auto seq_size_per_block = seqSizePerBlock();
auto total_blocks = stream_->seqLength() / seq_size_per_block;

for (size_t i = 0; i < stream_->currentBatchSize(); ++i) {
if (!last_block_aligned_ && !batch_resource_.cache_keys[i].empty()) {
batch_resource_.cache_keys[i].pop_back();
Expand Down
24 changes: 4 additions & 20 deletions rtp_llm/cpp/engine_base/stream/StreamCacheResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <memory>

namespace rtp_llm {

class GenerateStream;

struct ResourceContext {
Expand Down Expand Up @@ -42,6 +41,7 @@ class StreamCacheResource {
const std::vector<int64_t>& cacheKeys(int32_t batch_id) const;
absl::StatusOr<int> initKVBlock(int token_capacity, size_t reserve_step = 0);
absl::StatusOr<int> incrKVBlock(int token_capacity, size_t reserve_step = 0);
void insertIntoCache();
void fakeInitKVBlock();
int tryReleaseKVBlock(size_t nums);
absl::Status releaseSequenceKVCache(size_t total_seq_len, size_t release_seq_len);
Expand Down Expand Up @@ -103,25 +103,9 @@ class StreamCacheResource {
stream_ = stream;
}

bool reuseCache() const;
bool enable3FS() const;

std::string debugString() const {
std::stringstream debug_string;
debug_string << "StreamCacheResource {"
<< "need_release_resource: " << need_release_resource_ << ", batch_resource: [";

for (size_t i = 0; i < batch_resource_.batchSize(); i++) {
debug_string << " [";
for (size_t j = 0; j < batch_resource_.batch_block_id[i].size(); j++) {
debug_string << batch_resource_.batch_block_id[i][j] << " ";
}
debug_string << "],";
}

debug_string << "}";
return debug_string.str();
}
bool reuseCache() const;
bool enable3FS() const;
std::string debugString() const;

private:
GenerateStream* stream_;
Expand Down
Loading
Loading